In [1]:
!pip install torch-snippets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import torch
from torch import nn as nn
import math
from glob import glob
from torchsummary import summary
from torch_snippets import *
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image 
import sys
import os, os.path as osp


In [4]:
# Mount Google Drive at /content/drive; the dataset archive and the log
# directory used later in this notebook are referenced under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
!cp /content/drive/Shareddrives/MO434/dataset/project_dataset.zip .
!unzip -q project_dataset.zip

### Hyperparameters

In [10]:
# --- Data / model configuration ---
batchsize = 32               # mini-batch size of the training loader
n_class = 9                  # size of the network's output layer
input_shape = (3,224,224)    # (channels, height, width) fed to the network
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Optimisation ---
lr = 1e-5                    # optimizer learning rate
weight_decay=1e-4            # weight-decay (L2 penalty) coefficient
n_epochs = 100               # number of passes over the training set

# --- StepLR schedule: multiply the lr by `scheduler_gamma` every
# `scheduler_step` epochs ---
scheduler_step=25
scheduler_gamma = 0.1

# --- Paths ---
dataset_dir = "/content/project_dataset"

log_dir='/content/drive/Shareddrives/MO434/executions/classification_log'
best_model_path = osp.join(log_dir, 'best_model.pth')

# Create the log directory (and any missing parents); a directory that already
# exists is not an error.
os.makedirs(log_dir, exist_ok=True)


### Define dataset

Training, validation

In [11]:
import cv2
from PIL import Image
import os, os.path as osp

def list_join_dir(dir):
    """Return the entries of directory ``dir`` as a sorted list of joined paths.

    Each name returned by ``os.listdir(dir)`` is prefixed with ``dir``; the
    resulting paths are sorted in ascending order.
    """
    joined_paths = []
    for entry_name in os.listdir(dir):
        joined_paths.append(osp.join(dir, entry_name))
    joined_paths.sort()
    return joined_paths

class FishDataloader():
    """In-memory image classification dataset read from ``set_dir``.

    Despite its name, this class implements the map-style dataset protocol
    (``__getitem__`` / ``__len__``) and is wrapped by ``DataLoader`` elsewhere
    in this notebook.

    Every entry of ``<set_dir>/images`` is loaded eagerly at construction time,
    optionally masked with the same-named file in ``<set_dir>/masks``, and kept
    as an RGB ``PIL.Image``. The label is parsed from the file name (see
    ``_parse_label``).

    Args:
        set_dir: directory containing an ``images`` (and, when masking is
            requested, a ``masks``) sub-directory.
        apply_mask: when true, pixels where the mask is zero are blacked out.
        transforms: optional callable applied to the PIL image on every access.

    Raises:
        OSError: if an image, or a requested mask, cannot be read.
        ValueError: if a file name does not start with an integer label.
    """

    def __init__(self, set_dir, apply_mask=False, transforms=None) -> None:
        self.set_dir = set_dir
        self.apply_mask = apply_mask
        self.transforms = transforms

        self.img_dir = osp.join(set_dir, 'images')
        self.mask_dir = osp.join(set_dir, 'masks')

        self.images, self.gts = [], []
        for img_path in list_join_dir(self.img_dir):
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise OSError('could not read image: {}'.format(img_path))
            img_basename = osp.basename(img_path)

            if self.apply_mask:
                mask_path = osp.join(self.mask_dir, img_basename)
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    # Passing mask=None to bitwise_and would silently skip the
                    # masking, so fail loudly instead.
                    raise OSError('could not read mask: {}'.format(mask_path))
                img = cv2.bitwise_and(img, img, mask=mask)

            # OpenCV loads images as BGR; convert to RGB before wrapping in PIL.
            self.images.append(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
            self.gts.append(self._parse_label(img_basename))

    @staticmethod
    def _parse_label(basename):
        """Return the integer label encoded at the start of ``basename``.

        The label is the text before the first ``.`` and, within that, before
        the first ``_`` (e.g. ``'3_001.png'`` -> ``3``).
        """
        return int(basename.split('.')[0].split('_')[0])

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, applying ``transforms`` if set."""
        if self.transforms:
            return self.transforms(self.images[idx]), self.gts[idx]
        return self.images[idx], self.gts[idx]

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.images)


### Define transforms

In [12]:
from torchvision import transforms

# Validation/test pipeline: resize to 224x224 (bilinear, antialiased), convert
# to a tensor and normalise each channel with the given mean / std.
val_transforms = transforms.Compose([
    transforms.Resize(
        size=(224,224),
        interpolation=transforms.InterpolationMode.BILINEAR,
        max_size=None,
        antialias=True,
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# Training pipeline: currently the same steps as validation (no augmentation).
train_transforms = transforms.Compose([
    transforms.Resize(
        size=(224,224),
        interpolation=transforms.InterpolationMode.BILINEAR,
        max_size=None,
        antialias=True,
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

### Define dataloaders

In [13]:
from torch.utils.data import DataLoader

# Load the training and validation splits into memory, each with its own
# transform pipeline.
train_set = FishDataloader(
    set_dir=osp.join(dataset_dir, 'train'),
    transforms=train_transforms,
)
val_set = FishDataloader(
    set_dir=osp.join(dataset_dir, 'val'),
    transforms=val_transforms,
)

# Training batches are shuffled; validation is served one sample at a time,
# in order.
train_loader = DataLoader(dataset=train_set, batch_size=batchsize, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

### Define CNN architecture

In [14]:
import torch
from torch import nn as nn
import math

class CMKNet(nn.Module):
    """Two-column CNN classifier for 3x224x224 images.

    The input is processed by two parallel convolutional columns (3x3 and 5x5
    kernels). Their feature maps are concatenated along the channel axis, fused
    by a 1x1 convolution block and classified by a three-layer MLP.

    Args:
        dropout: dropout probability used after each hidden linear layer.
        num_classes: size of the output layer. When ``None`` (default) the
            notebook-level ``n_class`` is used, which preserves the behaviour
            of ``CMKNet()`` / ``CMKNet(dropout)``.
    """

    def __init__(self, dropout=0.5, num_classes=None) -> None:
        super().__init__()

        if num_classes is None:
            # Backward-compatible default: fall back to the notebook global.
            num_classes = n_class

        self.dropout = dropout
        self.num_classes = num_classes

        # input: 3, 224, 224
        self.col1 = nn.Sequential(
            self._conv_seq(3, 32, 3), # 32, 112, 112
            self._conv_seq(32, 64, 3), # 64, 56, 56
            self._conv_seq(64, 128, 3), # 128, 28, 28
        )

        self.col2 = nn.Sequential(
            self._conv_seq(3, 24, 5), # 24, 112, 112
            self._conv_seq(24, 48, 5), # 48, 56, 56
            self._conv_seq(48, 96, 5), # 96, 28, 28
        )

        # 1x1 conv block over the concatenated columns; its max pool halves
        # the spatial size once more.
        self.fuse = self._conv_seq(128+96, 32, 1) # 32, 14, 14

        self.classifier = nn.Sequential(
            self._linear_seq(32*14*14, 1024, dropout),
            self._linear_seq(1024, 256, dropout),
            nn.Linear(256, num_classes)
        )

        self._initialize_weights()

    def forward(self, x):
        """Return the class logits for a batch ``x`` of shape (N, 3, 224, 224)."""
        x1 = self.col1(x)
        x2 = self.col2(x)
        # Concatenate both columns along the channel dimension before fusing.
        x = self.fuse(torch.cat((x1, x2), 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def _fuse(self, in_channels, out_channels, kernel_size):
        """conv2d, batchnorm, relu (no pooling).

        Not called by ``__init__``, which builds ``self.fuse`` with
        ``_conv_seq`` instead.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size//2, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def _linear_seq(self, in_features, out_features, dropout):
        """linear, relu, dropout"""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

    def _conv_seq(self, in_channels, out_channels, kernel_size):
        """conv2d, batchnorm, relu, max pool

        The convolution keeps the spatial size (padding = kernel_size // 2) and
        the 2x2 max pool halves it.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size//2, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def _initialize_weights(self):
        """Initialise conv and linear layers; other modules keep their defaults."""
        # for each submodule of our network
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # fan-in: number of input values seen by one output unit
                n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                # weights ~ N(0, sqrt(2 / n))
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2. / n))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # weights ~ N(0, 0.01), biases set to 0
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


### Create the model

In [15]:
from torchsummary import summary

# Instantiate the network, move its parameters to the configured device and
# print a per-layer summary for the expected input shape.
model = CMKNet()
model.to(device)
summary(model, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 224, 224]             864
       BatchNorm2d-2         [-1, 32, 224, 224]              64
              ReLU-3         [-1, 32, 224, 224]               0
         MaxPool2d-4         [-1, 32, 112, 112]               0
            Conv2d-5         [-1, 64, 112, 112]          18,432
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
         MaxPool2d-8           [-1, 64, 56, 56]               0
            Conv2d-9          [-1, 128, 56, 56]          73,728
      BatchNorm2d-10          [-1, 128, 56, 56]             256
             ReLU-11          [-1, 128, 56, 56]               0
        MaxPool2d-12          [-1, 128, 28, 28]               0
           Conv2d-13         [-1, 24, 224, 224]           1,800
      BatchNorm2d-14         [-1, 24, 2

### Define the loss function, optimizer with L2 regularization and scheduler

In [16]:
# Cross-entropy over the raw class logits produced by the model.
criterion = nn.CrossEntropyLoss()
# Adam with L2 regularisation: `weight_decay` must be passed explicitly,
# otherwise Adam's default of 0 is used and no regularisation is applied.
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# Multiply the learning rate by `scheduler_gamma` every `scheduler_step` epochs.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

### Define training and validation functions

In [17]:
def train_batch(model, data, optimizer, criterion, device):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        model: network being trained.
        data: ``(images, targets)`` pair of tensors.
        optimizer: optimizer whose gradients are reset and which is stepped.
        criterion: loss called as ``criterion(predictions, targets)``.
        device: device the batch is moved to before the forward pass.

    Returns:
        The batch loss as a Python float.
    """
    images, labels = data
    images = images.to(device=device)
    labels = labels.to(device=device)

    # Clear gradients left over from the previous step before back-propagating.
    optimizer.zero_grad()
    batch_loss = criterion(model(images), labels)
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()

@torch.no_grad()
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return ``(mean loss, accuracy)``.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so calling this in the
    middle of training does not leave dropout / batch-norm in inference mode.

    Works with any batch size. The per-sample mean loss assumes ``criterion``
    returns the mean over the batch (the default for ``nn.CrossEntropyLoss``).

    Args:
        model: network to evaluate.
        loader: iterable of ``(images, labels)`` tensor batches.
        criterion: loss called as ``criterion(predictions, labels)``.
        device: device the batches are moved to.

    Returns:
        ``(loss, accuracy)`` as Python floats, averaged over all samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()

    n_samples = 0
    n_correct = 0
    loss_sum = 0.0
    try:
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            preds = model(imgs)

            batch_size = labels.size(0)
            n_correct += (preds.argmax(dim=1) == labels).sum().item()
            # Weight the batch-mean loss by the batch size so that the final
            # value is a per-sample average even if the last batch is smaller.
            loss_sum += criterion(preds, labels).item() * batch_size
            n_samples += batch_size
    finally:
        # Restore the mode the caller had set (train or eval).
        model.train(was_training)

    if n_samples == 0:
        raise ValueError('validate() received an empty loader')

    return loss_sum / n_samples, n_correct / n_samples # average loss and accuracy


### Train the model and report the results

In [None]:
# initializing variables
log      = Report(n_epochs)

# Lowest validation loss seen so far; the weights that achieve it are saved.
best_model_loss = float('inf')
n_batches = len(train_loader)

for epoch in range(n_epochs):
    # validate() puts the model in eval mode, so training mode (dropout active,
    # batch-norm statistics updated) must be re-enabled for every epoch, not
    # just once before the loop.
    model.train()

    epoch_train_loss = 0
    for data in train_loader:
        epoch_train_loss += train_batch(model, data, optimizer, criterion, device)

    val_loss, val_acc = validate(model, val_loader, criterion, device)
    # Plain Python float regardless of whether validate() hands back a tensor.
    val_loss = float(val_loss)
    log.record(epoch+1, train_loss=epoch_train_loss/n_batches, val_loss=val_loss, val_acc=val_acc, end='\r')

    lr_scheduler.step()

    # Checkpoint whenever the validation loss improves.
    if val_loss < best_model_loss:
        torch.save(model.state_dict(), best_model_path)
        best_model_loss = val_loss


EPOCH: 2.000  train_loss: 1.734  val_loss: 1.295  val_acc: 0.558  (91.39s - 4478.22s remaining)

### Plot the results

In [None]:
# Plot the per-epoch curves. The keys must match the metric names passed to
# log.record() in the training loop (train_loss, val_loss, val_acc); no
# training accuracy is recorded, so only the validation accuracy is plotted.
log.plot_epochs(['train_loss','val_loss'])
log.plot_epochs(['val_acc'])

### Evaluate the model

In [None]:
# Drop the training/validation datasets and loaders before loading the test
# split. pop() with a default (instead of `del`) keeps the cell re-runnable:
# a second run does not fail with NameError on names that are already gone.
for _name in ('train_loader', 'train_set', 'val_loader', 'val_set'):
    globals().pop(_name, None)

test_set = FishDataloader(osp.join(dataset_dir, 'test'), transforms=val_transforms)
test_loader  = DataLoader(test_set, batch_size=1, shuffle=False)

# load best model
model = CMKNet()
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.to(device)

# Report the average loss and accuracy on the test split.
loss, acc = validate(model, test_loader, criterion, device)
print('Loss: {:.6f} Acc: {:.6f}'.format(loss,acc))