In [None]:
from utils import *
from IPython.display import display
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50
from torch.utils.data import DataLoader
from torch import nn, optim
import random
# Seed Python, NumPy and PyTorch RNGs together. Seeding only `random` leaves
# torch-driven randomness (e.g. DataLoader shuffling) different on every run.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

from enum import Enum

class model_type(Enum):
    """Selector for the segmentation architecture to build.

    Members carry plain integer values. The previous ``U_Net = 1,`` had a
    trailing comma, which made its value the tuple ``(1,)`` and left the two
    members with inconsistent value types.
    """
    U_Net = 1
    DeepLabV3 = 2

In [None]:
# Run on the GPU when torch reports one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Transform pipeline: resize to 224x224, then convert to a tensor.
patches_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
def load_model(model:model_type,num_classes=4):
    """Instantiate the segmentation network selected by ``model``.

    Args:
        model (model_type): Architecture to build.
        num_classes (int): Number of classes forwarded to the model constructor.
                           Default: 4

    Returns:
        ``UNet(num_classes)`` for ``model_type.U_Net`` or
        ``deeplabv3_resnet50(num_classes=num_classes)`` for ``model_type.DeepLabV3``.

    Raises:
        ValueError: If ``model`` is not one of the supported ``model_type`` members.
    """
    if model == model_type.U_Net:
        return UNet(num_classes)
    if model == model_type.DeepLabV3:
        return deeplabv3_resnet50(num_classes=num_classes)
    # Falling through used to return None implicitly, which only surfaced later
    # as an AttributeError on `model.to(device)`. Fail here with a clear message.
    raise ValueError(f"Unsupported model type: {model!r}")
class EarlyStopping:
    """Track validation loss, checkpoint the model on improvement and flag when to stop."""

    def __init__(self, patience=20, delta=0,file_name:str="checkpoint.pth"):
        """
        Args:
            patience (int): Number of consecutive non-improving calls tolerated before
                            ``early_stop`` is set. Default: 20
            delta (float): Margin by which the loss has to undercut the best value seen so far
                           to count as an improvement. Default: 0
            file_name (str): Path the model ``state_dict`` is written to on improvement.
                             Default: "checkpoint.pth"
        """
        self.name = file_name
        self.patience = patience
        self.delta = delta
        self.val_loss_min = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model,verbose=False):
        """Register one validation result; checkpoint on improvement, otherwise count it."""
        improved = val_loss + self.delta < self.val_loss_min

        if not improved:
            # No improvement: one step closer to running out of patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # Improvement: persist the weights and restart the patience counter.
        self.save_checkpoint(val_loss, model,verbose)
        self.val_loss_min = val_loss
        self.counter = 0

    def save_checkpoint(self, val_loss, model,verbose):
        """Write ``model.state_dict()`` to ``self.name`` and record ``val_loss`` as the new best."""
        if verbose:
            # Printed before the update below so the message still shows the previous best.
            print(f'Validation loss decreased ({self.val_loss_min:.6f} -> {val_loss:.6f}).  Saving model ...')
        weights_snapshot = model.state_dict()
        torch.save(weights_snapshot, self.name)
        self.val_loss_min = val_loss

In [None]:
# Parameters
batch_size = 4
num_epochs = 200
learning_rate = 1e-3
decay = 1e-5  # handed to Adam as weight_decay
num_trials = 1
num_classes = 4
early_stop_warmup = 50  # the early stopper is only consulted on epochs whose index exceeds this
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for i in range(num_trials):
    # One variable feeds both EarlyStopping and the final report, so the path
    # that is written and the path that is printed cannot drift apart.
    checkpoint_path = f'./model-small-fold-{i}.pth'

    # Change the enum to U_Net to train a U-net model.
    # NOTE(review): both loops below read model(images)['out']; confirm that UNet
    # returns a mapping with an 'out' key before switching.
    model = load_model(model_type.DeepLabV3,num_classes=num_classes)
    model.to(device)

    # Get data
    train_subset,val_subset,_ = generate_set(patch_size=224,resize=(1120, 1344))

    # Build class weights for the loss from whole-page pixel counts. It might be better to do
    # this on patches to include the true distribution based on dynamically added crops.
    class_counts = np.zeros(num_classes)
    for page in train_subset: # Sum up the number of pixels for each class for each page
        for class_id in range(num_classes):
            class_counts[class_id] += np.sum(page.gt == class_id)

    total_pixels = np.sum(class_counts) # Count the total of pixels
    frequencies = class_counts / total_pixels
    # Inverse square-root frequency weighting.
    # NOTE(review): a class with zero pixels in train_subset yields an infinite weight here.
    weights = torch.tensor(np.sqrt(1.0 / frequencies).astype('float32'))
    weights = weights.to(device)

    # Loss, optimizer, stopper
    criterion = nn.CrossEntropyLoss(weight=weights)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=decay)
    earlystopper= EarlyStopping(20,0, checkpoint_path)

    # Prepare data
    train_dataset = PatchesDataset(train_subset,patches_transforms)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_dataset = PatchesDataset(val_subset,patches_transforms)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)

    # Training loop
    for epoch in range(num_epochs):

        # Called on both datasets at the start of every epoch.
        # NOTE(review): presumably this re-draws the random patches; confirm in utils.
        train_dataset.random_patch_generator(10)
        val_dataset.random_patch_generator(10)

        model.train()
        running_loss = 0.0
        for images, labels in train_loader:

            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)['out']
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight each batch loss by its size so the epoch figure is a per-sample mean.
            running_loss += loss.item() * labels.size(0)

        train_loss = running_loss / len(train_loader.dataset)

        # Validation loop
        model.eval()
        running_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)['out']
                loss = criterion(outputs,labels)
                running_loss += loss.item() * labels.size(0)
        val_loss = running_loss/len(val_loader.dataset)

        if epoch > early_stop_warmup:
            earlystopper(val_loss=val_loss,model=model,verbose=False)
        if earlystopper.early_stop:
            break
        # Report the epoch-level training loss. `loss` at this point holds the last
        # *validation* batch loss, which the old message mislabelled as the train loss.
        print(f'Epoch [{epoch+1}/{num_epochs}], Train loss: {train_loss:.4f}, Validation loss {val_loss}')

    # `epoch` is a 0-based index, so the number of epochs actually run is epoch + 1.
    print(f'Training finished! after {epoch + 1} epochs, Model saved to {checkpoint_path}')