In [23]:
from pathlib import Path

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets.voc import VOCSegmentation
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
from tqdm.auto import tqdm

from SSP.process_voc import VOCSegmentationWithJointTransform, JointTransform

  from .autonotebook import tqdm as notebook_tqdm


In [24]:
# Pascal VOC 2012 segmentation, train split, with a joint (image + mask) transform.
DATA_ROOT = Path('data')
# torchvision's VOC `download=True` skips the re-download when the archive checksum matches,
# but still re-extracts the whole archive on every run. Only ask for the download when the
# extracted tree is missing.
# NOTE(review): assumes the standard torchvision layout <root>/VOCdevkit/VOC2012 — confirm
# VOCSegmentationWithJointTransform does not change where the data is expected.
voc_dir = DATA_ROOT / 'VOCdevkit' / 'VOC2012'
dataset = VOCSegmentationWithJointTransform(
    root=str(DATA_ROOT),
    year='2012',
    image_set='train',
    download=not voc_dir.is_dir(),
    joint_transform=JointTransform(),
)

In [25]:
# Shuffled training mini-batches.
BATCH_SIZE = 32
NUM_WORKERS = 4
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # DeepLabV3's ASPP image-pooling branch applies BatchNorm to a 1x1 feature map; in train
    # mode a trailing batch of size 1 makes BatchNorm raise. Dropping the ragged last batch
    # keeps training robust to any BATCH_SIZE / dataset-size combination.
    drop_last=True,
    # Page-locked host memory speeds up host-to-GPU copies; has no effect on CPU-only runs.
    pin_memory=torch.cuda.is_available(),
)

In [30]:
# DeepLabV3 with a MobileNetV3-Large backbone, initialised from torchvision's pretrained
# checkpoint. `weights=` replaces the deprecated `pretrained=True` flag; `.DEFAULT` resolves
# to the same checkpoint that `pretrained=True` loaded.
model = deeplabv3_mobilenet_v3_large(weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT)

In [31]:
# Fine-tune the whole network with Adam on per-pixel cross-entropy.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seeds the shuffle order and the DataLoader workers' RNGs, so runs are reproducible.
torch.manual_seed(0)
model = model.to(device)
# Explicitly enter training mode (BatchNorm/Dropout) in case another cell called model.eval().
model.train()
# NOTE(review): 255 is presumably the VOC "void"/boundary label — confirm JointTransform keeps it.
criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 5
# NOTE: re-running this cell keeps training the already-updated `model` with a fresh optimizer.
epoch_bar = tqdm(range(num_epochs), desc='epochs')
for epoch in epoch_bar:
    # Accumulate on-device to avoid a host sync (loss.item()) on every step.
    running_loss = torch.zeros((), device=device)
    n_batches = 0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        # The model returns a dict; 'out' is the main segmentation head. An auxiliary 'aux'
        # head is presumably also returned with the pretrained weights — it is not used here.
        outputs = model(images)['out']
        loss = criterion(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.detach()
        n_batches += 1
    epoch_bar.set_postfix(loss=(running_loss / max(n_batches, 1)).item())

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [07:00<00:00, 84.13s/it]


In [33]:
# Persist only the learned parameters (the state_dict), not the pickled module object.
WEIGHTS_PATH = "model_weights.pth"
trained_state = model.state_dict()
torch.save(trained_state, WEIGHTS_PATH)
