In [1]:
import os
import glob

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from models.clip_seg import CLIPSeg
from data_core import myDataset
from utils import *

import clip

In [2]:
dim = 256           # Input image size (height and width), also selects the ./trainval_{dim} folder
num_classes = 3     # Trimap classes; raw trimap labels are {1,2,3}
# NOTE(review): nn.CrossEntropyLoss expects class indices in [0, num_classes), so the
# {1,2,3} trimap labels must be shifted to {0,1,2} somewhere (presumably in myDataset) — TODO confirm.
batch_size = 16
num_epochs = 50

seed = 42
# Seed torch's global RNG (all devices) so head initialisation and any RNG-driven
# data ordering are reproducible; without this, `seed` only reached train_test_split.
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Frozen CLIP ResNet-50 backbone: cast to fp32 (clip.load may hand back half-precision
# weights on GPU), switch to eval mode, and turn off gradients for every parameter.
clip_model, clip_preprocess = clip.load("RN50", device=device)
clip_model = clip_model.float()
clip_model.eval()
clip_model.requires_grad_(False)

train_images_dir = f'./trainval_{dim}/images'
train_masks_dir  = f'./trainval_{dim}/annotations'

image_paths = sorted(glob.glob(os.path.join(train_images_dir, '*.png')))
mask_paths  = sorted(glob.glob(os.path.join(train_masks_dir, '*.png')))

print(f'Found {len(image_paths)} images and {len(mask_paths)} masks')

# Images and masks are paired purely by sorted position downstream, so verify the
# pairing here instead of silently training on misaligned pairs.
# Assumes each mask shares its image's filename stem.
image_stems = [os.path.splitext(os.path.basename(p))[0] for p in image_paths]
mask_stems  = [os.path.splitext(os.path.basename(p))[0] for p in mask_paths]
assert image_paths, f"No .png images found in {train_images_dir}"
assert image_stems == mask_stems, (
    f"Image/mask filenames do not correspond one-to-one "
    f"({len(image_stems)} images vs {len(mask_stems)} masks)"
)

Found 3680 images and 3680 masks


In [3]:
# Hold out a fraction of the trainval pairs for validation. Splitting both path lists in
# one train_test_split call applies the same permutation to each, keeping pairs aligned.
val_fraction = 0.1
train_img_paths, val_img_paths, train_mask_paths, val_mask_paths = train_test_split(
    image_paths, mask_paths, test_size=val_fraction, random_state=seed, shuffle=True
)
print(f"Training images: {len(train_img_paths)}  Validation images: {len(val_img_paths)}")

# Create dataset instances.
train_dataset = myDataset(train_img_paths, train_mask_paths, dim)
val_dataset   = myDataset(val_img_paths, val_mask_paths, dim)

# A dedicated seeded generator makes the training shuffle order reproducible without
# relying on the global torch RNG state; re-running this cell restarts the same order.
train_shuffle_generator = torch.Generator().manual_seed(seed)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, generator=train_shuffle_generator
)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Training images: 3312  Validation images: 368


In [4]:
# Segmentation model built on top of the (frozen) CLIP backbone. Only the parameters of
# `model.seg_head` are given to the optimizer, so nothing else is updated by Adam.
# NOTE(review): the RuntimeError recorded for the training cell (a [64, 64, 1, 1] conv
# receiving a 32-channel 64x64 map) suggests CLIPSeg's encoder skips part of CLIP's
# ResNet stem before the first residual layer — TODO confirm in models/clip_seg.py.
model = CLIPSeg(clip_model, num_classes).to(device)

head_learning_rate = 1e-3
optimizer = optim.Adam(params=model.seg_head.parameters(), lr=head_learning_rate)

# Multi-class cross-entropy; targets must be integer class indices in [0, num_classes).
criterion = nn.CrossEntropyLoss()

In [5]:
# Train for `num_epochs`, evaluating after every epoch and overwriting a single
# checkpoint whenever the validation loss beats the best value seen so far.
# NOTE(review): if train_one_epoch calls model.train() and CLIPSeg registers clip_model as
# a submodule, the frozen backbone's BatchNorm layers will still update their running
# statistics (requires_grad=False does not stop that) — TODO confirm in utils/CLIPSeg.
checkpoint_path = f"clip_weights/clip_model_{dim}_epochs_{num_epochs}.pth"
# torch.save does not create missing parent directories, so make sure it exists up front.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

best_val_loss = float('inf')  # set to infinity initially so the first finite loss is saved
best_epoch = None

# Training loop
for epoch in tqdm(range(1, num_epochs + 1), desc="Training Epochs"):

    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    print(f"Epoch: {epoch}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Save the best model.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), checkpoint_path)

if best_epoch is None:
    # Only possible if every validation loss was NaN/inf.
    print("No checkpoint saved: validation loss never improved.")
else:
    print(f"Best Val Loss: {best_val_loss:.4f} at epoch {best_epoch} -> {checkpoint_path}")


Training Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

RuntimeError: Given groups=1, weight of size [64, 64, 1, 1], expected input[16, 32, 64, 64] to have 64 channels, but got 32 channels instead