# U-Net Fine-Tuning

## Importing Modules

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

%config InlineBackend.figure_format = 'retina'

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import os
from tqdm.auto import tqdm

import warnings
warnings.filterwarnings("ignore")

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"On Device: {device}")

### Drive Mount

In [None]:
# Colab-only step: mount Google Drive at /content/drive/ so that the
# /content/drive/MyDrive/... paths used in the rest of the notebook resolve.
from google.colab import drive
drive.mount("/content/drive/")

## Defining Dataset and DataLoader

In [None]:
# Two independent, initially empty containers that will hold
# [clean_path, blurred_path] pairs for the train and test splits.
train_path_list, test_path_list = [], []

In [None]:
# Root folder that contains the train / blurred_train / test / blurred_test
# image directories.
IMAGES_ROOT = '/content/drive/MyDrive/Restoring Brain CT and Brain MRI/data/Images'


def _pair_paths(clean_dir, clean_names, blurred_dir, blurred_names):
    """Return ``[clean_path, blurred_path]`` pairs for two sorted listings.

    Files are matched purely by their position in the two name lists, so the
    lists must be the same length.

    Args:
        clean_dir: Directory holding the clean (target) images.
        clean_names: Sorted file names found in ``clean_dir``.
        blurred_dir: Directory holding the blurred (input) images.
        blurred_names: Sorted file names found in ``blurred_dir``.

    Returns:
        A list of two-item lists ``[clean_path, blurred_path]``.

    Raises:
        ValueError: If the two listings differ in length. ``zip`` would
            otherwise drop the surplus files silently and could pair a
            blurred image with the wrong clean target.
    """
    if len(clean_names) != len(blurred_names):
        raise ValueError(
            f"Cannot pair images: {len(clean_names)} files in {clean_dir!r} "
            f"but {len(blurred_names)} files in {blurred_dir!r}"
        )
    return [
        [os.path.join(clean_dir, clean), os.path.join(blurred_dir, blur)]
        for clean, blur in zip(clean_names, blurred_names)
    ]


# train data path list
train_dir = os.path.join(IMAGES_ROOT, 'train')
blurred_train_dir = os.path.join(IMAGES_ROOT, 'blurred_train')
clean_train_paths = sorted(os.listdir(train_dir))
blurred_train_paths = sorted(os.listdir(blurred_train_dir))

train_path_list.extend(
    _pair_paths(train_dir, clean_train_paths, blurred_train_dir, blurred_train_paths)
)

# test data path list
test_dir = os.path.join(IMAGES_ROOT, 'test')
blurred_test_dir = os.path.join(IMAGES_ROOT, 'blurred_test')
clean_test_paths = sorted(os.listdir(test_dir))
blurred_test_paths = sorted(os.listdir(blurred_test_dir))

test_path_list.extend(
    _pair_paths(test_dir, clean_test_paths, blurred_test_dir, blurred_test_paths)
)

In [None]:
# defining dataset
class dataset(Dataset):
  """Paired-image dataset that yields ``(blurred, clean)`` samples.

  Each element of ``img_list`` is indexed as ``[clean_path, blurred_path]``.
  Both images are opened, converted to RGB and passed through the same
  ``data_transform`` callable.
  """

  def __init__(self, img_list, data_transform):
    """Store the path pairs and the transform applied to every image.

    Args:
      img_list: Sequence of ``[clean_path, blurred_path]`` entries.
      data_transform: Callable applied to each loaded PIL image.
    """
    super().__init__()
    self.input = img_list
    self.transforms = data_transform

  def __len__(self):
    """Return the number of image pairs."""
    return len(self.input)

  @staticmethod
  def _load_rgb(path):
    """Open ``path`` and return an RGB copy, closing the file afterwards."""
    # convert() returns a new in-memory image, so the file can be released
    # immediately instead of waiting for garbage collection.
    with Image.open(path) as img:
      return img.convert("RGB")

  def __getitem__(self, idx):
    """Return ``(train_img, target_img)`` for pair ``idx``.

    ``train_img`` is the transformed blurred image (index 1 of the pair) and
    ``target_img`` is the transformed clean image (index 0 of the pair).
    """
    target_img = self._load_rgb(self.input[idx][0])
    train_img = self._load_rgb(self.input[idx][1])

    target_img = self.transforms(target_img)
    train_img = self.transforms(train_img)

    return train_img, target_img

In [None]:
# Preprocessing shared by every image: resize to 256x256, convert to a
# tensor, then normalise each channel with fixed mean / std values
# (presumably dataset statistics; confirm where they were computed).
img_transforms = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=torch.tensor([0.1317, 0.1317, 0.1318]),
            std=torch.tensor([0.2261, 0.2261, 0.2261]),
        ),
    ]
)

In [None]:
# Wrap each split's path pairs in the paired-image dataset; both splits use
# the same preprocessing pipeline.
train_data, test_data = (
    dataset(path_pairs, img_transforms)
    for path_pairs in (train_path_list, test_path_list)
)

In [None]:
# Sanity check: number of image pairs in the train and test datasets.
len(train_data), len(test_data)

In [None]:
BATCH_SIZE = 32
RANDOM_SEED = 42
# os.cpu_count() returns None when the core count cannot be determined, and
# DataLoader rejects None for num_workers; fall back to 0, which loads
# batches in the main process.
NUM_WORKERS = os.cpu_count() or 0

# Seed the RNGs before the loaders are created so that the training shuffle
# order is repeatable between runs.
torch.cuda.manual_seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
# defining dataloader
# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS)

# Test batches keep the dataset order.
test_dataloader = DataLoader(dataset=test_data,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=NUM_WORKERS)

In [None]:
# Sanity check: number of batches each loader produces per epoch.
len(train_dataloader), len(test_dataloader)

## Training

In [None]:
# Change the working directory to the project folder on Drive (IPython
# magic), then import the project-local training loop. Presumably
# model_trainer.py lives in that folder; confirm.
%cd /content/drive/MyDrive/Restoring\ Brain\ CT\ and\ Brain\ MRI/
from model_trainer import train

In [None]:
# Download the U-Net published in the mateuszbuda/brain-segmentation-pytorch
# hub repository, requesting its pretrained weights, and move it to the
# selected device.
model = torch.hub.load(
    'mateuszbuda/brain-segmentation-pytorch',
    'unet',
    in_channels=3,
    out_channels=1,
    init_features=32,
    pretrained=True,
)
model = model.to(device)

In [None]:
# Swap the model's `conv` layer for a freshly initialised 1x1 convolution
# mapping 32 feature channels to 3 output channels.
model.conv = nn.Conv2d(
    in_channels=32,
    out_channels=3,
    kernel_size=1,
    stride=1,
).to(device)

In [None]:
# Display the module tree to inspect the architecture after replacing `conv`.
model

In [None]:
# Freeze the four encoder stages: their parameters stop receiving gradients,
# so only the rest of the network is updated during fine-tuning.
for encoder_stage in (model.encoder1, model.encoder2, model.encoder3, model.encoder4):
  for param in encoder_stage.parameters():
    param.requires_grad = False

In [None]:
# Mean-squared-error criterion handed to train().
loss_fn = nn.MSELoss()

# Adam over all model parameters with a base learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Linear learning-rate schedule left at its library defaults.
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)

In [None]:
# Output folder passed to train() as save_output_path; it receives the first
# test image's output after every epoch.
folder_path = "/content/drive/MyDrive/Restoring Brain CT and Brain MRI/epochs_output/unet1/"
EPOCHS1 = 10

# Run the fine-tuning loop and keep the per-epoch histories it returns.
train_loss1, test_loss1, test_ssim1, test_psnr1 = train(
    model,
    optimizer,
    loss_fn,
    scheduler,
    train_dataloader,
    test_dataloader,
    EPOCHS1,
    test_path_list,
    img_transforms,
    device,
    save_output_path=folder_path,
)

In [None]:
# Draw every recorded history as its own labelled curve on a single set of
# axes (all four series share one y-axis).
for series, label in (
    (train_loss1, 'train_loss'),
    (test_loss1, 'test_loss'),
    (test_ssim1, 'test_ssim'),
    (test_psnr1, 'test_psnr'),
):
    plt.plot(series, label=label)

plt.xlabel("Epochs")
plt.legend()
plt.show()

In [None]:
# Persist the fine-tuned weights (state_dict only) to Google Drive.
from pathlib import Path

# Destination folder and file name for the checkpoint.
MODEL_NAME = "unet_epoch10.pth"
MODEL_PATH = Path("/content/drive/MyDrive/Restoring Brain CT and Brain MRI/models")
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# Create the folder (and any missing parents); no error if it already exists.
MODEL_PATH.mkdir(parents=True, exist_ok=True)

# Write the parameters to disk.
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(model.state_dict(), MODEL_SAVE_PATH)