## Imports

In [None]:
import copy
import torch
import torchvision
from torch import nn
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.byol_transform import (BYOLTransform, BYOLView1Transform, BYOLView2Transform,)
from lightly.utils.scheduler import cosine_schedule

## Data Preprocessing

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import lightly
import matplotlib.pyplot as plt
from PIL import Image  # Ensure you import this for image handling
import scipy.io

class OCTDataset(Dataset):
    """In-memory dataset of OCT slices read from MATLAB ``.mat`` volumes.

    Every regular, non-hidden file in ``data_dir`` is loaded with
    ``scipy.io.loadmat`` and must provide an ``'images'`` array of shape
    ``(x, y, nimages)``. For each slice the central half of the second axis
    is kept, min-max scaled to ``[0, 1]`` and cached as ``float32``.

    Args:
        data_dir: Directory containing the volume files.
        transform: Optional callable applied to a PIL image. It is called
            twice per item, once for each returned view.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.valid_data = []  # list of 2-D float32 arrays scaled to [0, 1]
        self._process_all_files()

    @staticmethod
    def _normalize(image):
        """Min-max scale ``image`` to ``[0, 1]``; a flat image maps to zeros."""
        low = image.min()
        span = image.max() - low
        if span > 0:
            return (image - low) / span
        # A constant slice has zero range; dividing would fill it with NaN.
        return np.zeros_like(image)

    def _process_all_files(self):
        """Load every volume in ``data_dir`` and cache its cropped slices."""
        # Sorted so that the index -> slice mapping is reproducible.
        for file_name in sorted(os.listdir(self.data_dir)):
            file_path = os.path.join(self.data_dir, file_name)
            # Skip sub-directories and hidden files (e.g. macOS ``.DS_Store``)
            # that ``loadmat`` cannot parse.
            if file_name.startswith(".") or not os.path.isfile(file_path):
                continue
            images = scipy.io.loadmat(file_path)['images']
            _, y, nimages = images.shape
            ini, fin = y // 4, (y * 3) // 4  # central half of the second axis
            for i in range(nimages):
                image = images[:, ini:fin, i].astype(np.float32)
                self.valid_data.append(self._normalize(image))

    def __len__(self):
        """Return the number of cached slices."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return ``(view1, view2)`` for the slice at ``idx``.

        Each view is the result of an independent ``transform`` call on the
        8-bit PIL version of the slice. Without a transform the same
        un-augmented PIL image is returned for both views.
        """
        # Transforms operate on PIL images, so rescale [0, 1] -> uint8.
        image = Image.fromarray((self.valid_data[idx] * 255).astype(np.uint8))
        if self.transform is None:
            return image, image
        # NOTE(review): a multi-view transform such as BYOLTransform presumably
        # already returns a pair of views per call, so each element here would
        # itself be a pair - confirm against the consumer of this dataset.
        view1 = self.transform(image)
        view2 = self.transform(image)
        return view1, view2

# BYOL augmentation pipeline built from Lightly's `BYOLTransform`: two view
# transforms at 512 px input size with Gaussian blur disabled.
byol_augmentations = BYOLTransform(
    view_1_transform=BYOLView1Transform(gaussian_blur=0.0, input_size=512),
    view_2_transform=BYOLView2Transform(gaussian_blur=0.0, input_size=512),
)

# Dataset location (alternative path used on Google Drive kept for reference).
#train_data_dir = "drive/MyDrive/Research_Data/train_data"
train_data_dir = "/Users/ashleasmith/Desktop/Postgrad CS/Research Project/Research_Data"

# Build the dataset and a shuffling loader over it.
train_dataset = OCTDataset(train_data_dir, transform=byol_augmentations)
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=32,
    num_workers=4,
)



## BYOL Model

In [None]:
class BYOL(nn.Module):
    """BYOL network: an online branch plus a gradient-free momentum branch.

    The online branch is ``backbone -> projection_head -> prediction_head``.
    The momentum branch holds deep copies of the backbone and projection head
    with ``requires_grad`` deactivated.

    Args:
        backbone: Feature extractor; its output is flattened from dimension 1
            before being fed to the projection head.
        num_features: Width of the flattened backbone output. When ``None``
            it is inferred from the backbone (see ``_infer_num_features``),
            so that e.g. a ResNet-50 trunk (2048 channels) gets a matching
            projection head instead of a hard-coded 512.
    """

    def __init__(self, backbone, num_features=None):
        super().__init__()

        if num_features is None:
            num_features = self._infer_num_features(backbone)

        self.backbone = backbone
        self.projection_head = BYOLProjectionHead(num_features, 1024, 256)
        self.prediction_head = BYOLPredictionHead(256, 1024, 256)

        # Momentum (target) copies start identical to the online modules.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

    @staticmethod
    def _infer_num_features(backbone, default=512):
        """Guess the backbone's output width from its last conv/linear layer.

        Walks ``backbone.modules()`` and returns the ``out_features`` or
        ``out_channels`` of the last module exposing one of them; returns
        ``default`` when no such module exists.
        """
        width = default
        for module in backbone.modules():
            channels = getattr(module, "out_features", None)
            if channels is None:
                channels = getattr(module, "out_channels", None)
            if isinstance(channels, int):
                width = channels
        return width

    def forward(self, x):
        """Return the online prediction ``p`` for input batch ``x``."""
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        """Return the detached momentum-branch projection ``z`` for ``x``."""
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        z = z.detach()  # targets must not propagate gradients
        return z

In [None]:
# Randomly initialised ResNet-50. Calling the builder without arguments loads
# no pretrained weights and avoids the deprecated `pretrained=` keyword.
resnet = torchvision.models.resnet50()
# Drop the final classification layer so the trunk ends at global pooling.
# NOTE: ResNet-50's pooled output has `resnet.fc.in_features` (2048) channels;
# the projection head's input width has to match that.
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BYOL(backbone)

## Training

In [None]:
# Training configuration: number of epochs, loss, optimiser and device.
epochs = 100
criterion = NegativeCosineSimilarity()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.06,
    momentum=0.9,
)

# Run on the GPU when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model.to(device)

In [None]:
print("Start Training")

for epoch in range(epochs):
    # Running sum of the per-batch losses for this epoch.
    total_loss = 0

    # Momentum coefficient follows a cosine schedule from 0.996 towards 1.
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)

    for batch in train_loader:
        # The first element of each batch holds the two augmented views.
        x0, x1 = batch[0]

        # Move the momentum branch towards the online weights before the
        # forward passes.
        update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
        update_momentum(model.projection_head, model.projection_head_momentum, m=momentum_val)

        x0 = x0.to(device)
        x1 = x1.to(device)

        p0 = model(x0)
        z0 = model.forward_momentum(x0)

        p1 = model(x1)
        z1 = model.forward_momentum(x1)

        # Symmetrised loss: each online prediction is matched against the
        # momentum projection of the other view.
        loss = 0.5 * (criterion(p0, z1) + criterion(p1, z0))
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(train_loader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
    #log to wandb