In [1]:
def print_decorated(message, border_char='-', padding=1, width=80):
    border = border_char * width
    for _ in range(padding):
        print(border)
    print(message.center(width))
    for _ in range(padding):
        print(border)

### Import Necessary libraries

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from bbox_plotter import yolo_to_corners
from torchvision.ops import complete_box_iou_loss
import shutil
from PIL import Image
import pandas as pd
import cv2
import os
import csv
torch.manual_seed(42)
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

### CNN Model Architecture
Change the adaptive pooling layer according to the env (model is saved with different architectures in different systems, but the main model contains adaptive avg. pooling layer)

In [3]:
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((8, 8))  

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features= 128 * 8 * 8, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=4)
        
        
    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = self.global_avg_pool(self.pool3(torch.relu(self.conv3(x))))
                
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        
        return x
        

### Loss Functions
#### CIoU for training, validation and evaluation between augmented and original unlabelled predictions
#### Weighted CIoU for labelled and unlabelled combined samples with adaptive ramping up of weight function

In [4]:
class CompleteBoxLoss(nn.Module):
    def __init__(self):
        super(CompleteBoxLoss, self).__init__()

    def forward(self, pred_boxes, true_boxes, reduction):
        ciou = complete_box_iou_loss(pred_boxes, true_boxes, reduction)
        
        return ciou
    
ciou_loss_function = CompleteBoxLoss()

In [5]:
class WeightedCompleteBoxLoss(nn.Module):
    def __init__(self, unlabeled_weight=1.5):
        super(WeightedCompleteBoxLoss, self).__init__()
        self.unlabeled_weight = unlabeled_weight

    def forward(self, pred_boxes, true_boxes, is_labeled):
        ciou = complete_box_iou_loss(pred_boxes, true_boxes)
        labeled_loss = ciou[is_labeled == 1] 
        pseudo_labeled_loss = ciou[is_labeled == 0]
        pseudo_labeled_loss = self.unlabeled_weight * pseudo_labeled_loss
        total_loss = torch.cat([labeled_loss, pseudo_labeled_loss]).mean()
        return total_loss

### Defining and Loading model

In [6]:
model = CNNModel().to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [7]:
def load_checkpoint(checkpoint, architecture, optimizer):
    print("loading checkpoint...")
    checkpoint = torch.load(checkpoint)
    
    model = architecture()
    
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model = model.to('cuda')
    return model.eval()

### Dataset Classes
- For labelled, unlabelled and combined datasets

In [8]:
class LabeledDataset(Dataset):
    def __init__(self, image_folder, label_subfolder, transform=None):
        self.image_folder = image_folder
        self.label_folder = os.path.join(image_folder, label_subfolder) 
        self.image_files = sorted(os.listdir(image_folder))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_folder, self.image_files[idx])
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image) 

        label_filename = os.path.splitext(self.image_files[idx])[0] + '.txt'
        label_path = os.path.join(self.label_folder, label_filename)
        labels = []
        
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    class_id, x, y, w, h = map(float, line.strip().split())
                    labels.append([class_id, x, y, w, h])

        labels = torch.tensor(labels, dtype=torch.float32)

        if self.transform:
            image = self.transform(image)

        return image, labels


class UnlabeledDataset(Dataset):
    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.image_files = sorted(os.listdir(image_folder))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_folder, self.image_files[idx])
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image = Image.fromarray(image)

        if self.transform:
            image = self.transform(image)

        return image, image_path

In [9]:
class ZebraFishDataset(Dataset):
    def __init__(self, labeled_image_folder, pseudo_labeled_image_folder, label_folder, pseudo_label_folder, transform=None):
        self.labeled_image_folder = labeled_image_folder
        self.pseudo_labeled_image_folder = pseudo_labeled_image_folder
        self.label_folder = label_folder
        self.pseudo_label_folder = pseudo_label_folder
        self.transform = transform
        self.labeled_images = [f for f in os.listdir(labeled_image_folder) if f.endswith('.jpg')]
        self.pseudo_labeled_images = [f for f in os.listdir(pseudo_labeled_image_folder) if f.endswith('.jpg')]
        self.all_images = self.labeled_images + self.pseudo_labeled_images
        self.is_labeled = torch.tensor([1] * len(self.labeled_images) + [0] * len(self.pseudo_labeled_images))

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, idx):
        img_name = self.all_images[idx]
        if self.is_labeled[idx]:
            img_path = os.path.join(self.labeled_image_folder, img_name)
            label_path = os.path.join(self.label_folder, img_name.rsplit(".jpg", 1)[0] + ".txt")
        else:
            img_path = os.path.join(self.pseudo_labeled_image_folder, img_name)
            label_path = os.path.join(self.pseudo_label_folder, img_name.rsplit(".jpg", 1)[0] + ".txt")
        with open(label_path, 'r') as f:
            label = f.read().strip().split()
        label = [float(x) for x in label[1:]]  
        label = torch.tensor(label, dtype=torch.float32)
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label, self.is_labeled[idx]
    

class ZebrafishValDataset(Dataset):
    def __init__(self, image_folder, label_subfolder, transform=None):
        self.image_folder = image_folder
        self.label_folder = os.path.join(image_folder, label_subfolder) 
        self.image_files = sorted(os.listdir(image_folder))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_folder, self.image_files[idx])
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image) 

        label_filename = os.path.splitext(self.image_files[idx])[0] + '.txt'
        label_path = os.path.join(self.label_folder, label_filename)
        label = []
        
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    class_id, x, y, w, h = map(float, line.strip().split())
                    label.append([x, y, w, h])

        if self.transform:
            image = self.transform(image)

        return torch.tensor(image, dtype=torch.float32), torch.tensor(label, dtype=torch.float32).squeeze(0)

In [10]:
def load_datasets(labeled_image_folder, label_subfolder, unlabeled_image_folder):
    transform = transforms.Compose([
        transforms.Resize((224, 224))
        ])
    labeled_dataset = LabeledDataset(labeled_image_folder, label_subfolder, transform=transform)
    unlabeled_dataset = UnlabeledDataset(unlabeled_image_folder, transform=transform)
    return labeled_dataset, unlabeled_dataset

In [11]:
labeled_image_folder = r"D:\Praharsha\code\CAMZ\data\raw\Labelled_images"
label_subfolder = r"D:\Praharsha\code\CAMZ\data\raw\Labelled_images\labels"
unlabeled_image_folder = r"D:\Praharsha\code\CAMZ\data\raw\unlabelled_images"
pseudo_labelled_images = r"D:\Praharsha\code\CAMZ\data\raw\pseudo_labelled_images"
pseudo_labels_subfolder = r"D:\Praharsha\code\CAMZ\data\raw\pseudo_labelled_images\pseudo_labels"
val_image_data = r"D:\Praharsha\code\CAMZ\data\raw\validation_data"
val_labels_subfolder = r"D:\Praharsha\code\CAMZ\data\raw\validation_labels"

### Generate pseudo labels from unlabelled data

In [12]:
def Generate_pseudo_labels(csv_filename, unlabeled_dataset, model, difference_metric=ciou_loss_function):
    # Define augmentations
    augmentations = [
        transforms.RandomVerticalFlip(p=1.0),
        transforms.ColorJitter(brightness=0.3, contrast=0.3),
        transforms.GaussianBlur(3),
        transforms.RandomSolarize(threshold=0.5, p=1.0)
    ]

    # Collect entries in a list
    entries = []

    # Process each image in the unlabeled dataset
    for image_idx in range(len(unlabeled_dataset)):
        predictions = []
        augmented_predictions_list = []
        image, file_name = unlabeled_dataset[image_idx]
        image_tensor = transforms.ToTensor()(image)
        input_image_tensor = image_tensor.unsqueeze(0)
        input_image_tensor = input_image_tensor.to('cuda')
        # Get the original prediction
        model.eval()
        with torch.no_grad():
            original_prediction = model(input_image_tensor)
            predictions.append(original_prediction)

        # Get predictions for augmented images
        for i, augmentation in enumerate(augmentations):
            augmented_image = augmentation(image)
            augmented_image_tensor = transforms.ToTensor()(augmented_image)
            augmented_input_image_tensor = augmented_image_tensor.unsqueeze(0)
            augmented_input_image_tensor = augmented_input_image_tensor.to('cuda')

            model.eval()
            with torch.no_grad():
                augmented_prediction = model(augmented_input_image_tensor)
                if i == 0:
                    augmented_prediction[0][1] = 1 - augmented_prediction[0][1]
                augmented_predictions_list.append(augmented_prediction)

        # Compute the mean of augmented predictions
        augmented_predictions_tensor = torch.stack(augmented_predictions_list)
        mean_augmented_prediction = augmented_predictions_tensor.mean(dim=0)

        # Compute the final mean prediction
        mean_prediction = torch.stack([original_prediction, mean_augmented_prediction]).mean(dim=0)

        # Convert predictions to corners format
        original_prediction = yolo_to_corners(original_prediction, image_height=224, image_width=224)
        mean_augmented_prediction = yolo_to_corners(mean_augmented_prediction, image_height=224, image_width=224)

        # Compute the CIoU loss
        ciou_loss = difference_metric(original_prediction, mean_augmented_prediction, reduction='mean')

        # Add the new entry to the list
        entries.append({
            "filename": file_name,
            "bounding_box": mean_prediction.tolist(),
            "ciou_loss": ciou_loss.item()
        })

    # Create the DataFrame once at the end
    df = pd.DataFrame(entries)

    # Save the DataFrame to a CSV file
    df.to_csv(csv_filename, index=False)
    print_decorated(f"Generated Pseudo labels for Unlabelled Dataset ({len(unlabeled_dataset)} samples)", border_char='.')

### Filtering top-K labels from ciou loss between original unlabelled predictions and augmented predictions

In [13]:
def filter_top_pseudo_labels(csv_filename, pseudo_labeled_folder, pseudo_labels_folder, unlabeled_folder,  num_top=5):
    df = pd.read_csv(csv_filename)
    top_df = df.sort_values(by="ciou_loss").head(num_top)
    top_files = top_df["filename"].tolist()
    top_boxes = top_df["bounding_box"].apply(eval).tolist() 
    for file, box in zip(top_files, top_boxes):
        src_path = os.path.join(unlabeled_folder, os.path.basename(file))
        dest_path = os.path.join(pseudo_labeled_folder, os.path.basename(file))
        if os.path.exists(src_path):
            shutil.move(src_path, dest_path)
        txt_filename = os.path.join(pseudo_labels_folder, os.path.basename(file).split(".")[0] + ".jpg.txt")
        with open(txt_filename, "w") as f:
            f.write(f"0 {round(box[0][0], 5)} {round(box[0][1], 5)} {round(box[0][2], 5)} {round(box[0][3], 5)}\n")
    print_decorated(f"Filtered Top-{num_top} Label", border_char='*')

In [14]:
def Initialize_writer(file_path,columns = ['epoch','loss','val_loss']):
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, mode='w', newline="") as file:
        writer = csv.writer(file)
        writer.writerow(columns)

In [15]:
def validate(model, val_loader, loss_function, epoch):
    model.eval()
    ciou_total_batch_loss = 0.0
    ciou_loss_per_batch_val_history = []
    batch_num = 1
    
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to('cuda'), labels.to('cuda')
            outputs = yolo_to_corners(model(inputs), image_width=224, image_height=224)
            norm_labels = yolo_to_corners(labels, image_width=224, image_height=224)
            
            ciou_loss_per_batch = loss_function(outputs, norm_labels, 'mean')
            ciou_loss_per_batch_val_history.append([epoch+1, batch_num, ciou_loss_per_batch.item()])
            ciou_total_batch_loss += ciou_loss_per_batch.item()
            batch_num += 1
        
    ciou_total_loss = ciou_total_batch_loss / len(val_loader)
    print(f"Validation for Epoch [{epoch+1}], CIoU: {ciou_total_loss:.4f}")
    return ciou_total_loss, ciou_loss_per_batch_val_history

In [16]:
def train(model, train_loader, val_loader, optimizer, loss_function, val_loss_function,
          epoch_history_csv_path, batch_train_history_csv_path, batch_val_history_csv_path, 
          save_model_checkpoint_path, num_epochs=1, patience=20, delta=0.001):
    
    Initialize_writer(epoch_history_csv_path)
    Initialize_writer(batch_train_history_csv_path, columns=['epoch', 'batch', 'loss'])
    Initialize_writer(batch_val_history_csv_path, columns=['epoch', 'batch', 'loss'])
    best_val_ciou_loss = float('inf')
    model.train()        
    early_stopping = 0
    for epoch in range(num_epochs):
        ciou_total_batch_loss = 0.0
        ciou_loss_per_batch_train_history = []
        batch_num = 1
        for inputs, labels, is_labeled in train_loader:
            inputs = inputs.to('cuda')
            labels = labels.to('cuda')
            is_labeled = is_labeled.to('cuda')
            optimizer.zero_grad()
            norm_labels = yolo_to_corners(labels, image_width=720, image_height=720)
            outputs = yolo_to_corners(model(inputs), image_width=720, image_height=720)
            ciou_loss_per_batch = loss_function(outputs, norm_labels, is_labeled)
            ciou_loss_per_batch.backward()
            optimizer.step()
            ciou_loss_per_batch_train_history.append([epoch+1, batch_num, ciou_loss_per_batch.item()])
            ciou_total_batch_loss += ciou_loss_per_batch.item()
            batch_num+=1
        ciou_loss_per_epoch = ciou_total_batch_loss / len(train_loader)
        print(end='\n')
        print(f"Epoch [{epoch+1}/{num_epochs}], CIoU: {ciou_loss_per_epoch:.4f}", end=" --- ")
        print(f"Validation for Epoch [{epoch+1}/{num_epochs}]", end=", ")
        val_ciou_loss, ciou_loss_per_batch_val_history = validate(model, val_loader, val_loss_function, epoch)        
        if (best_val_ciou_loss - val_ciou_loss) > delta:
            best_val_ciou_loss = val_ciou_loss
            checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
            print("\n")
            print(f"=========Saving Checkpoint======= at Epoch:[{epoch+1}/{num_epochs}]", end="\n")
            torch.save(checkpoint, save_model_checkpoint_path)
            early_stopping = 0
        else:
            early_stopping += 1         
        with open(epoch_history_csv_path, mode='a', newline="") as file:
            loss_writer = csv.writer(file)
            loss_writer.writerow([epoch+1, ciou_loss_per_epoch, val_ciou_loss])
        
        with open(batch_train_history_csv_path, 'a', newline="") as file:
            loss_writer = csv.writer(file)
            loss_writer.writerows(ciou_loss_per_batch_train_history)
        
        with open(batch_val_history_csv_path, 'a', newline="") as file:
            loss_writer = csv.writer(file)
            loss_writer.writerows(ciou_loss_per_batch_val_history)
        if early_stopping >= patience:
            print(f"Early stopping occurred at {epoch+1}", end='\n')
            break
    print(f"The best Validation Loss is: {best_val_ciou_loss}")
    return best_val_ciou_loss

In [17]:
best_iter_val_loss = float('inf')
max_iterations = 23
iteration = 0
breaking_point = 0.01
iter_break_count = 0
patience = 3
while iteration < max_iterations:
    print_decorated(f"Starting Semi-supervised Training - Round {iteration+1}/{max_iterations}", border_char="*")
    # load the datasets
    csv_filename = rf"D:\Praharsha\code\CAMZ\models\pseudo_labels_semi{iteration}.csv"
    labelled_dataset, unlabeled_dataset = load_datasets(labeled_image_folder, label_subfolder, unlabeled_image_folder)
    # load the model
    if iteration == 0:
        model = load_checkpoint("D:\Praharsha\code\CAMZ\models\model_history\CNN_checkpoint.pth.tar", CNNModel, optimizer)
        model = model.to('cuda')
    if iteration > 0:
         model = load_checkpoint("D:\Praharsha\code\CAMZ\models\model_history\CNN_checkpoint_semi.pth.tar", CNNModel, optimizer)
         model = model.to('cuda')
    # generate pseudo labels
    Generate_pseudo_labels(csv_filename, unlabeled_dataset, model, ciou_loss_function)
    if (len(unlabeled_dataset) > 0):
        # filter pseudo labels and generate new datasets
        filter_top_pseudo_labels(csv_filename, pseudo_labelled_images, pseudo_labels_subfolder, unlabeled_image_folder, num_top=1000)
        # prepare the new dataset for training
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
            ])
        train_dataset = ZebraFishDataset(labeled_image_folder, pseudo_labelled_images, label_subfolder, pseudo_labels_subfolder, transform=transform)
        val_dataset = ZebrafishValDataset(val_image_data, val_labels_subfolder, transform=transform)
        train_loader = DataLoader(dataset = train_dataset, batch_size=32, shuffle=True)
        val_loader = DataLoader(dataset = val_dataset, batch_size=32, shuffle=True)
        # # train the model with new datasets
        weighted_ciou_loss_function = WeightedCompleteBoxLoss()
        epoch_history_csv_file = f"D:\Praharsha\code\CAMZ\models\Semi_supervised_training_history\CNN_loss_logger_epoch_wise_semi{iteration}.csv"
        batch_histroy_csv_file = f"D:\Praharsha\code\CAMZ\models\Semi_supervised_training_history\CNN_loss_logger_batch_wise_semi{iteration}.csv"
        batch_val_history_csv_file = f"D:\Praharsha\code\CAMZ\models\Semi_supervised_training_history\CNN_loss_logger_val_batch_wise_semi{iteration}.csv"
        model_checkpoint_path = "D:\Praharsha\code\CAMZ\models\model_history\CNN_checkpoint_semi.pth.tar"
        current_iter_val_loss = train(model, train_loader, val_loader, optimizer, weighted_ciou_loss_function ,save_model_checkpoint_path=model_checkpoint_path,
                                      epoch_history_csv_path=epoch_history_csv_file, batch_train_history_csv_path=batch_histroy_csv_file, batch_val_history_csv_path=batch_val_history_csv_file,
                                      val_loss_function=ciou_loss_function, num_epochs=200)
        
        if (best_iter_val_loss - current_iter_val_loss) < breaking_point:
            iter_break_count += 1
            if iter_break_count == patience:
                print_decorated("There is no further improvement observed!!!", border_char='*')
                break
        best_iter_val_loss = current_iter_val_loss
    elif len(unlabeled_dataset) == 0:
        print_decorated("Unlabelled Dataset is all used up!!!", '*')
        break
    iteration += 1
print_decorated(f"Semi-Supervised Learning Piepline is Executed Sucessfully!", '*')
print_decorated(f"The latest best validation CIoU Loss is: {best_iter_val_loss}", '.')

********************************************************************************
                 Starting Semi-supervised Training - Round 1/23                 
********************************************************************************
loading checkpoint...
................................................................................
         Generated Pseudo labels for Unlabelled Dataset (24740 samples)         
................................................................................
********************************************************************************
                            Filtered Top-1000 Label                             
********************************************************************************

Epoch [1/200], CIoU: 0.2238 --- Validation for Epoch [1/200], Validation for Epoch [1], CIoU: 0.3572



Epoch [2/200], CIoU: 0.2238 --- Validation for Epoch [2/200], Validation for Epoch [2], CIoU: 0.3560



Epoch [3/200], CIoU: 0.2238 --- Validati