In [12]:
import torch
import numpy as np
from torch import optim, nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from torchvision.models.segmentation import fcn_resnet50

from data.pascal_voc_dataset import PascalVOCSegmentation
from data.utils import get_pascal_dataloader

In [13]:
! setenv CUDA_VISIBLE_DEVICES 0,1

In [14]:
LEARNING_RATE = 3e-4
BATCH_SIZE = 32
EPOCHS = 2
DATA_ROOT = "."
num_workers = 0

device = "cuda" if torch.cuda.is_available() else "cpu"

In [15]:
train_dataset = PascalVOCSegmentation(
    root_dir=DATA_ROOT,
    split="train",
    input_size=513
)

val_dataset = PascalVOCSegmentation(
    root_dir=DATA_ROOT,
    split="val",
    input_size=513
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True
    
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=False
)

In [16]:
len(train_dataloader)

45

In [17]:
model = fcn_resnet50(pretrained=True).to(device)

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=-1) 

In [18]:
def compute_iou(outputs, targets):
    smooth = 1e-6
    preds = torch.argmax(outputs, dim=1)
    
    ious = []
    for cls in range(1, 21): 
        pred_inds = preds == cls
        target_inds = targets == cls
        
        intersection = (pred_inds & target_inds).float().sum()
        union = (pred_inds | target_inds).float().sum()
        
        if union.item() > 0:
            iou = (intersection + smooth) / (union + smooth)
            ious.append(iou.item())
    
    return np.mean(ious) if ious else 0

In [19]:
for epoch in tqdm(range(EPOCHS)):
    model.train()
    train_running_loss = 0
    for idx, img_mask in enumerate(tqdm(train_dataloader)):
        img = img_mask[0].float().to(device)
        mask = img_mask[1].long().to(device)

        y_pred = model(img)['out'] 
        optimizer.zero_grad()

        loss = criterion(y_pred, mask)
        train_running_loss += loss.item()
        
        loss.backward()
        optimizer.step()

    train_loss = train_running_loss / (idx + 1)

    model.eval()
    val_running_loss = 0
    val_iou = 0
    with torch.no_grad():
        for idx, img_mask in enumerate(tqdm(val_dataloader)):
            img = img_mask[0].float().to(device)
            mask = img_mask[1].long().to(device)
            
            y_pred = model(img)['out'] 
            loss = criterion(y_pred, mask)
            
            batch_iou = compute_iou(y_pred, mask)
            val_iou += batch_iou

            val_running_loss += loss.item()

        val_loss = val_running_loss / (idx + 1)
        mean_iou = val_iou / (idx + 1)

    print("-"*30)
    print(f"Train Loss EPOCH {epoch+1}: {train_loss:.4f}")
    print(f"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}")
    print(f"Mean IoU EPOCH {epoch+1}: {mean_iou:.4f}")
    print("-"*30)

torch.save(model.state_dict(), ".")

100%|██████████| 45/45 [01:22<00:00,  1.82s/it]
100%|██████████| 46/46 [00:35<00:00,  1.29it/s]
 50%|█████     | 1/2 [01:57<01:57, 117.77s/it]

------------------------------
Train Loss EPOCH 1: 0.4838
Valid Loss EPOCH 1: 0.4704
Mean IoU EPOCH 1: 0.3017
------------------------------


