In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from bbox_plotter import yolo_to_corners
from torchvision.ops import complete_box_iou_loss
from PIL import Image
import numpy as np
import csv
import os
from bbox_plotter import visualize_prediction
import pandas as pd

In [3]:
class UnlabelledDataset():
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.jpg.jpg')]
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, img_path

In [4]:
unlabelled_dataset_dir = r"E:\DIL\Fish_classification_ViT\unlabelled_data"

transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor()])

unlabelled_dataset = UnlabelledDataset(unlabelled_dataset_dir, transform=transform)
print(unlabelled_dataset.__len__())

15


In [5]:
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # self.global_avg_pool = nn.AdaptiveAvgPool2d((8, 8))  

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features= 128 * 26 * 26, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=4)
        
        
    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = self.pool3(torch.relu(self.conv3(x)))                
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        
        return x
        

In [6]:
class CompleteBoxLoss(nn.Module):
    def __init__(self):
        super(CompleteBoxLoss, self).__init__()

    def forward(self, pred_boxes, true_boxes, reduction):
        ciou = complete_box_iou_loss(pred_boxes, true_boxes, reduction)
        
        return ciou


In [7]:
model = CNNModel()
ciou_loss_function = CompleteBoxLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [8]:
def load_checkpoint(checkpoint, architecture, optimizer):
    print("loading checkpoint...")
    checkpoint = torch.load(checkpoint)
    
    model = architecture()
    
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    return model.eval()
model = load_checkpoint("E:\Code\CAMZ\models\CNN_checkpoint.pth.tar", CNNModel, optimizer)

loading checkpoint...


  checkpoint = torch.load(checkpoint)


In [9]:
def generate_pseudo_labels(model, csv_path, dataloader):
    pseudo_labels = []
    image_paths = []
    
    with torch.no_grad():
        for images, paths in dataloader:
            outputs = yolo_to_corners(model(images), 720, 720)
            optimizer.zero_grad()
                
            pseudo_labels.extend(np.array(outputs))
            image_paths.extend(paths)
            
            with open(csv_path, mode='w', newline='') as file:
                writer = csv.writer(file)
                writer.writerow(['image_path', 'pseudo_label'])
                for img_path, label in zip(image_paths, pseudo_labels):
                    writer.writerow([img_path, label])        
    print("-----------------------------------pseudo labels are generated!----------------------------------------")
        


In [10]:
csv_path = r'E:\DIL\Fish_classification_ViT\unlabelled_data\pseudo_labels.csv' 
unlabelled_dataloader = DataLoader(unlabelled_dataset, batch_size=32, shuffle=False)
generate_pseudo_labels(model,csv_path, dataloader=unlabelled_dataloader)       

-----------------------------------pseudo labels are generated!----------------------------------------


In [11]:
class PseudoLablledDataset(Dataset):
    def __init__(self, img_dir, pseudo_labels_df, transform=None):
        self.img_dir = img_dir
        self.pseudo_labels_df = pseudo_labels_df
        self.transform = transform
        
    def __len__(self):
        return len(self.pseudo_labels_df)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.pseudo_labels_df.iloc[idx, 0])
        label = self.pseudo_labels_df.iloc[idx, 1]
        label = label.strip('[]')
        label_a = np.array([float(x) for x in label.split()], dtype=np.float32)
        label = torch.tensor(label_a, dtype=torch.float32)
        
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

In [12]:
pseudo_labels_df = pd.read_csv(csv_path) 
pseudo_lables_dataset = PseudoLablledDataset(img_dir=unlabelled_dataset_dir, pseudo_labels_df=pseudo_labels_df,
                                             transform=transform)

pseudo_labels_dataloader = DataLoader(pseudo_lables_dataset, batch_size=32, shuffle=True)

In [14]:
def Initialize_writer(file_path):
    print(file_path)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, mode='w', newline="") as file:
        writer = csv.writer(file)
        writer.writerow(['epoch', 'loss'])

In [15]:
num_epochs = 2
unlabelled_history_csv_file = r"E:\Code\CAMZ\models\model_history\unlabelled_CNN_loss_logger.csv"
unlablled_checkpoints_file = r"E:\Code\CAMZ\models\model_history\unlabelled_CNN_checkpoint.pth.tar"

In [22]:
def train(model, train_loader, optimizer, 
          loss_function,
          history_csv_path, 
          save_model_checkpoint_path, num_epochs=1, patience=10, delta = 0.001):
    
    Initialize_writer(history_csv_path)
    early_stopping = 0
    
    best_ciou_loss = float('inf')
    
    model.train()
        
    for epoch in range(num_epochs):
        ciou_total_batch_loss = 0.0
         
        for inputs, labels in pseudo_labels_dataloader:
            inputs = inputs
            labels = labels
            
            optimizer.zero_grad()
            
            outputs = yolo_to_corners(model(inputs), image_width=720, image_height=720)
                     
            ciou_loss_per_batch = loss_function(outputs, labels, 'mean')
            
            ciou_loss_per_batch.backward()
            
            optimizer.step()
            
            ciou_total_batch_loss += ciou_loss_per_batch.item()
            
        ciou_loss_per_epoch = ciou_total_batch_loss / len(train_loader)
        
        print(end='\n')
        print(f"Epoch [{epoch+1}/{num_epochs}], CIoU: {ciou_loss_per_epoch:.4f}", end=" --- ")

        
        
        if (best_ciou_loss - ciou_loss_per_epoch) > delta:
            best_ciou_loss = ciou_loss_per_epoch
            checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
            print("\n")
            print(f"=========Saving Checkpoint======= at Epoch:[{epoch+1}/{num_epochs}]", end="\n")
            torch.save(checkpoint, save_model_checkpoint_path)
            early_stopping = 0
        
        else:
            early_stopping+=1 
            

        
        with open(history_csv_path, mode='a', newline="") as file:
            loss_writer = csv.writer(file)
            loss_writer.writerow([epoch+1, ciou_loss_per_epoch])
        
        if early_stopping >= patience:
            print(f"Early stopping occured at {epoch+1}")
            break
             
    print(f"The best Validation Loss is: {best_ciou_loss}")

In [23]:
train(model, unlabelled_dataloader, optimizer, ciou_loss_function, unlabelled_history_csv_file,
      unlablled_checkpoints_file, num_epochs=2)

E:\Code\CAMZ\models\model_history\unlabelled_CNN_loss_logger.csv

Epoch [1/2], CIoU: 0.0000 --- 


Epoch [2/2], CIoU: 0.0000 --- The best Validation Loss is: 0.0
