In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from bbox_plotter import yolo_to_corners
from torchvision.ops import complete_box_iou_loss
import numpy as np
import csv
import os
from sklearn.model_selection import KFold

In [2]:
class ZebrafishDataset(Dataset):
    """Map-style dataset pairing zebrafish images with their box labels.

    Both arrays are loaded eagerly from ``.npy`` files and indexed in
    parallel, so they must contain the same number of samples (axis 0).

    Args:
        images_path: Path to the ``.npy`` file holding the images.
        labels_path: Path to the ``.npy`` file holding one label per image.
        transform: Optional callable applied to each image before it is
            converted to a float32 tensor (e.g. ``transforms.ToTensor()``).

    Raises:
        ValueError: If the image and label arrays differ in length.
    """

    def __init__(self, images_path, labels_path, transform=None):
        self.images = np.load(images_path)
        self.labels = np.load(labels_path)
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"images ({len(self.images)}) and labels ({len(self.labels)}) "
                "must contain the same number of samples"
            )
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx`` as float32 tensors."""
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return self._as_float_tensor(image), self._as_float_tensor(label)

    @staticmethod
    def _as_float_tensor(value):
        """Return ``value`` as a float32 tensor that owns its own storage."""
        if torch.is_tensor(value):
            # torch.tensor(<tensor>) also copies, but emits a UserWarning on
            # every call; an explicit detached copy is the documented form.
            return value.detach().to(dtype=torch.float32, copy=True)
        return torch.tensor(value, dtype=torch.float32)

In [3]:
class CNNModel(nn.Module):
    """Small three-stage CNN that regresses four box values per image.

    Each stage is an unpadded 3x3 convolution, a ReLU and a 2x2/stride-2
    max-pool. Three fully connected layers then map the flattened features
    to 4 outputs, which a sigmoid squashes into [0, 1].

    Args:
        in_channels: Number of channels in the input images (default 3).
        input_size: Height/width in pixels of the square input images. The
            default of 224 reproduces the previously hard-coded flattened
            size of ``128 * 26 * 26``.

    Raises:
        ValueError: If ``input_size`` is too small to survive the three
            conv/pool stages.
    """

    def __init__(self, in_channels=3, input_size=224):
        super(CNNModel, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Spatial side length after the three stages: each unpadded 3x3 conv
        # trims 2 pixels and each 2x2/stride-2 pool halves (rounding down).
        feature_size = input_size
        for _ in range(3):
            feature_size = (feature_size - 2) // 2
        if feature_size < 1:
            raise ValueError(f"input_size={input_size} is too small for CNNModel")

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=128 * feature_size * feature_size, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        """Map a batch of images to 4 sigmoid-activated values per image.

        ``x`` is presumably (N, in_channels, input_size, input_size); other
        spatial sizes will not match ``fc1``.
        """
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = self.pool3(torch.relu(self.conv3(x)))

        x = self.flatten(x)

        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))

        return x
        

In [4]:
def reset_weights(model):
    """Re-initialise the direct children of ``model`` that support it.

    Only immediate children are visited. Deeper layers are reached when this
    function is passed to ``nn.Module.apply``, which calls it on every module.
    Children without a ``reset_parameters`` method are left untouched.
    """
    resettable = (child for child in model.children() if hasattr(child, 'reset_parameters'))
    for child in resettable:
        child.reset_parameters()

In [5]:
class CompleteBoxLoss(nn.Module):
    """``nn.Module`` wrapper around torchvision's ``complete_box_iou_loss``.

    Args:
        eps: Small constant forwarded to ``complete_box_iou_loss`` to avoid
            division by zero (torchvision's own default is 1e-7).
    """

    def __init__(self, eps=1e-7):
        super(CompleteBoxLoss, self).__init__()
        self.eps = eps

    def forward(self, pred_boxes, true_boxes, reduction='mean'):
        """Return the Complete-IoU loss between predicted and target boxes.

        Args:
            pred_boxes: Predicted boxes; torchvision expects (x1, y1, x2, y2).
            true_boxes: Target boxes with the same shape as ``pred_boxes``.
            reduction: ``'none'``, ``'mean'`` or ``'sum'``. Defaults to
                ``'mean'``, the value every caller in this notebook uses.
        """
        # Pass by keyword so the call cannot silently bind to a different
        # positional parameter of complete_box_iou_loss.
        return complete_box_iou_loss(pred_boxes, true_boxes, reduction=reduction, eps=self.eps)


In [6]:
def Initialize_writer(file_path, header=('epoch', 'loss', 'val_loss')):
    """Create (or truncate) a CSV history file and write its header row.

    Args:
        file_path: Destination CSV path. Missing parent directories are
            created; a bare file name is written to the working directory.
        header: Column names for the first row.
    """
    parent_dir = os.path.dirname(file_path)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, mode='w', newline="") as file:
        csv.writer(file).writerow(header)

In [7]:
# Raw strings keep the Windows separators literal. In plain literals,
# sequences such as "\C" and "\d" are invalid escapes (SyntaxWarning on 3.12+).
images_path = r"E:\Code\CAMZ\data\interim\X_labelled_data.npy"
labels_path = r"E:\Code\CAMZ\data\interim\y_labelled_data.npy"

# ToTensor converts each image array to a CHW float tensor
# (presumably the stored images are HWC uint8 - confirm).
transform = transforms.ToTensor()
dataset = ZebrafishDataset(images_path, labels_path, transform=transform)

In [8]:
# Cross-validation / training hyper-parameters.
k_folds = 2
num_epochs = 1
loss_function = CompleteBoxLoss()
batch_size = 16

# Seed every RNG the training loop draws from, so a fresh Run All reproduces
# the same results. KFold(shuffle=True) with no random_state uses NumPy's
# global RNG; weight initialisation and SubsetRandomSampler use PyTorch's.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Output locations. Raw strings keep the Windows backslashes literal; the
# previous plain literals relied on "\C", "\m" and "\{" being invalid escapes.
MODELS_DIR = r"E:\Code\CAMZ\models"
HISTORY_DIR = os.path.join(MODELS_DIR, "CrossValidation_history")
# Image width/height passed to yolo_to_corners for labels and predictions alike.
IMAGE_SIZE = 720


def _mean_batch_ciou(model, loader, loss_fn, optimizer=None):
    """Return the mean of the per-batch CIoU losses of ``model`` over ``loader``.

    With an ``optimizer`` the model is put in training mode and one gradient
    step is taken per batch. Without one, the model is put in eval mode and
    gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)
    total = 0.0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            # Targets and predictions go through the same conversion so the
            # loss compares boxes in one coordinate convention.
            target_boxes = yolo_to_corners(labels, image_width=IMAGE_SIZE, image_height=IMAGE_SIZE)
            pred_boxes = yolo_to_corners(model(inputs), image_width=IMAGE_SIZE, image_height=IMAGE_SIZE)
            batch_loss = loss_fn(pred_boxes, target_boxes, 'mean')
            if training:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
            total += batch_loss.item()
    return total / len(loader)


kfold = KFold(n_splits=k_folds, shuffle=True)
fold_results = []

for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    fold_no = fold + 1
    print(f"Fold: {fold_no}")
    print('-----------------------------------------------------------------------------------------')

    # Each loader draws only from its fold's indices, in random order.
    train_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=torch.utils.data.SubsetRandomSampler(train_ids))
    test_loader = DataLoader(dataset, batch_size=batch_size,
                             sampler=torch.utils.data.SubsetRandomSampler(test_ids))

    model = CNNModel()
    print("Resetting the model for upcoming training session")
    model.apply(reset_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    history_csv_path = os.path.join(HISTORY_DIR, f"{fold_no}_history")
    Initialize_writer(history_csv_path)

    best_val_ciou_loss = float('inf')

    for epoch in range(num_epochs):
        epoch_no = epoch + 1

        ciou_loss_per_epoch = _mean_batch_ciou(model, train_loader, loss_function, optimizer)
        print()
        print(f"Fold [{fold_no}/{k_folds}], Epoch [{epoch_no}/{num_epochs}], CIoU: {ciou_loss_per_epoch:.4f}", end=" --- ")

        val_ciou_loss = _mean_batch_ciou(model, test_loader, loss_function)
        print(f"Validation for Epoch [{epoch_no}/{num_epochs}], CIoU: {val_ciou_loss:.4f}")

        # Keep a checkpoint whenever this fold's validation loss improves.
        if val_ciou_loss < best_val_ciou_loss:
            best_val_ciou_loss = val_ciou_loss
            torch.save(model.state_dict(),
                       os.path.join(MODELS_DIR, f"model_fold_{fold_no}_epoch_{epoch_no}.pth"))
            print(f"Saving best model for Fold {fold_no} at Epoch {epoch_no}")

        with open(history_csv_path, mode='a', newline="") as file:
            csv.writer(file).writerow([epoch_no, ciou_loss_per_epoch, val_ciou_loss])

    print(f"Best Validation CIoU for Fold {fold_no}: {best_val_ciou_loss:.4f}")
    fold_results.append((fold_no, best_val_ciou_loss))

# Summary across folds.
print("\nK-Fold Cross-Validation Results:")
print("Fold\tBest Validation CIoU Loss")
for fold, best_loss in fold_results:
    print(f"{fold}\t{best_loss:.4f}")

avg_val_ciou_loss = sum(loss for _, loss in fold_results) / len(fold_results)
print(f"\nAverage Validation CIoU Loss Across All Folds: {avg_val_ciou_loss:.4f}")

        