### Import Dependencies

In [1]:
import random
# Seeds only Python's built-in `random` module; torch's RNG is not affected
# by this call.
random.seed(0)

In [2]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.transforms import v2

### Constant Variables

In [3]:
# Root folder of the dataset, relative to the notebook's working directory.
DATASET_PATH = './knee-osteoarthritis_2'

In [4]:
# Sub-folders of DATASET_PATH holding the on-disk train / val / test images.
TRAIN_PATH = f'{DATASET_PATH}/train'
VAL_PATH = f'{DATASET_PATH}/val'
TEST_PATH = f'{DATASET_PATH}/test'

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

In [6]:
# Class indices 0..3; passed to getConfusionMatrixDisplay during training.
classes = range(4)

### Dataset 

In [7]:
from src.dataset.augmented_dataset import get_KneeOsteoarthritis_Edges, KneeOsteoarthritis_Edges

# Only a tensor conversion is applied when loading images from disk.
transform_toTensor = transforms.Compose([transforms.ToTensor()])

# Load the three on-disk folders separately, then pool them so the notebook
# can draw its own 70/10/20 split below.
train_dataset = torchvision.datasets.ImageFolder(TRAIN_PATH, transform_toTensor)
val_dataset = torchvision.datasets.ImageFolder(VAL_PATH, transform_toTensor)
test_dataset = torchvision.datasets.ImageFolder(TEST_PATH, transform_toTensor)

dataset_all = torch.utils.data.ConcatDataset([train_dataset, val_dataset, test_dataset])

# random_split draws from torch's RNG, which `random.seed` does not control.
# A dedicated seeded generator keeps the split identical across runs, so
# samples cannot migrate between train and test when the notebook is re-run.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset_all, [0.7, 0.1, 0.2], generator=split_generator
)

# Wrap every split in the project dataset; train_many unpacks its batches as
# (images, edges, labels).
train_dataset = KneeOsteoarthritis_Edges(train_dataset)
val_dataset = KneeOsteoarthritis_Edges(val_dataset)
test_dataset = KneeOsteoarthritis_Edges(test_dataset)

In [None]:
# Sanity check: number of samples in each split.
print(len(train_dataset), len(val_dataset), len(test_dataset))

In [9]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img):
    """Display an image tensor with matplotlib.

    The tensor is rescaled with ``img / 2 + 0.5``, converted to a NumPy
    array, and its axis 0 is moved to the end before plotting
    (presumably channel-first to channel-last — confirm against callers).

    NOTE(review): the datasets in this notebook are built with ``ToTensor()``
    only; confirm the inputs really are normalized to [-1, 1], otherwise this
    rescaling compresses the image into [0.5, 1].
    """
    rescaled = img / 2 + 0.5  # undo a (0.5, 0.5) normalization
    as_array = rescaled.numpy()
    plt.imshow(as_array.transpose(1, 2, 0), cmap='gray')
    plt.show()

In [None]:
# Inspect one training sample. train_many unpacks batches as
# (images, edges, labels), so elements 0 and 1 are presumably the image and
# its edge map — confirm against KneeOsteoarthritis_Edges.
row = train_dataset[1]
normal_ex = row[0]
augmented_ex = row[1]
print(normal_ex.shape, augmented_ex.shape)

### Data Loader

In [11]:
from torch.utils.data import DataLoader

BATCH_SIZE = 128

# Only the training loader is reshuffled every epoch. Evaluation loaders keep
# a fixed order so batch composition is the same on every evaluation pass,
# which keeps per-batch metrics comparable across epochs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Model

In [12]:
class IntermediarySpaceModel(nn.Module):
    """Two-branch CNN that fuses image features with edge-map features.

    ``imagesClassifier`` consumes 3-channel input and ``edgesClassifier``
    consumes 1-channel input. Each branch ends in a linear layer producing
    48 features; the two vectors are concatenated and passed through
    ``outputCombiner`` to obtain ``num_classes`` logits.

    The branch-final linear layers expect 24*7*7 (images) and 24*6*6 (edges)
    flattened features, so the spatial input size of each branch is fixed by
    the architecture (for example 256x256 images and 224x224 edge maps
    satisfy it; the sizes actually produced by the dataset are not visible
    here).

    Args:
        num_classes: number of output logits.
        dropout: base dropout probability; the convolutional branches use
            scaled-down fractions of it, the combiner uses it unscaled.
    """

    def __init__(self, num_classes: int = 5, dropout: float = 0.5) -> None:
        super().__init__()

        # Basic width unit; most layers use twice this many channels.
        base = 24
        wide = base * 2

        # Branch for the 3-channel images.
        image_layers = [
            nn.Conv2d(3, wide, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Dropout(p=dropout * 0.2),
            nn.Conv2d(wide, wide, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Dropout(p=dropout * 0.4),
            nn.Conv2d(wide, wide, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout * 0.6),
            nn.Conv2d(wide, base, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Flatten(),
            nn.Dropout(p=dropout * 0.8),
            nn.Linear(base * 7 * 7, wide),
        ]
        self.imagesClassifier = nn.Sequential(*image_layers)

        # Shallower branch for the 1-channel edge maps.
        edge_layers = [
            nn.Conv2d(1, wide, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Dropout(p=dropout * 0.4),
            nn.Conv2d(wide, wide, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Dropout(p=dropout * 0.6),
            nn.Conv2d(wide, base, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Flatten(),
            nn.Dropout(p=dropout * 0.8),
            nn.Linear(base * 6 * 6, wide),
        ]
        self.edgesClassifier = nn.Sequential(*edge_layers)

        # Head applied to the concatenated branch outputs (2 * wide features).
        combiner_layers = [
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(base * 4, base * 3),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(base * 3, base),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(base, num_classes),
        ]
        self.outputCombiner = nn.Sequential(*combiner_layers)

    def forward(self, images: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images and matching edge maps."""
        image_features = self.imagesClassifier(images)
        edge_features = self.edgesClassifier(edges)

        # Join the two feature vectors along dimension 1 before the head.
        fused = torch.cat((image_features, edge_features), 1)
        return self.outputCombiner(fused)

In [13]:
# from src.models.custom import CustomModel

# model = EarlyIntermediarySpaceModel(3, 0)
# Build the 4-class network with a base dropout of 0.5 and move it to the
# selected device.
model = IntermediarySpaceModel(num_classes=4, dropout=0.5).to(device)

In [None]:
# print(sum(p.numel() for p in net.classifier.parameters()) ,sum(p.numel() for p in net.edgesClassifier.parameters()) )
# Total number of parameter elements in the model.
total_parameter_count = sum(param.numel() for param in model.parameters())
print(total_parameter_count)

# Number of elements in parameters that receive gradients.
trainable_parameters = (param for param in model.parameters() if param.requires_grad)
print(sum(param.numel() for param in trainable_parameters))

### Training Loop

#### Setting optimizer

In [None]:
import torch.optim as optim
from src.other import getClassesFrequency

# NOTE(review): the helper's name suggests it returns class frequencies, and
# the values are used directly as per-class loss weights. Confirm it returns
# the intended weighting and that the dict values are ordered by class index.
class_weights = getClassesFrequency(train_dataset)
weights_tensor = torch.Tensor(list(class_weights.values()))
weights_tensor = weights_tensor.to(device)
print(class_weights, weights_tensor)

# Class-weighted cross-entropy, optimized with Adam at a learning rate of 1e-3.
criterion = nn.CrossEntropyLoss(weight=weights_tensor)
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# decayRate100 = 0.8
# Target factor by which the learning rate should shrink over 100 scheduler steps.
decayRate100 = 0.4
# Per-step factor: the 100th root, so that decayRate1 ** 100 == decayRate100.
decayRate1 = decayRate100**(1/100)
# ExponentialLR multiplies the learning rate by `gamma` on every step() call;
# train_many steps it once per epoch.
my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate1)

print(decayRate1)

In [17]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

#### Setting Logger

In [18]:
# Experiment name; used as the TensorBoard run directory below.
EXP_NAME = "long_cm"

from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer storing event files under logs/experiments/<EXP_NAME>.
logger = SummaryWriter(log_dir=f"logs/experiments/{EXP_NAME}")

In [19]:
# Global epoch index, advanced by train_many so that repeated calls continue
# the same step axis. Re-running this cell resets it to 0.
epochCounter = 0

#### Training Loop

In [35]:
from src.evaluate import evaluate_augmented_model
from src.other import getConfusionMatrixDisplay

def train_many(model, epochs_nr, logger = None, lr_scheduler = None, regularization_type = "L2", lambda_reg=0.01):
    """Train ``model`` for ``epochs_nr`` epochs, validating after each one.

    Relies on the notebook-level ``train_loader``, ``val_loader``,
    ``optimizer``, ``criterion``, ``device`` and ``classes``. Advances the
    global ``epochCounter`` so that consecutive calls share one step axis.

    Args:
        model: network called as ``model(images, edges)`` and returning logits.
        epochs_nr: number of epochs to run in this call.
        logger: optional TensorBoard ``SummaryWriter``; nothing is logged
            when it is None.
        lr_scheduler: optional scheduler, stepped once at the end of each epoch.
        regularization_type: ``"L1"`` or ``"L2"`` adds that penalty (computed
            over all model parameters) to the loss; any other value disables
            the penalty.
        lambda_reg: multiplier applied to the penalty term.
    """
    global epochCounter

    for _ in range(epochs_nr):  # loop over the dataset multiple times
        # Dropout is only active in training mode. Re-enable it every epoch
        # in case the validation helper switched the model to eval mode
        # (its implementation is not visible here).
        model.train()

        epoch_correct = 0.0
        epoch_samples = 0
        epoch_batches = 0
        running_loss = 0.0

        for images, edges, labels in train_loader:
            images = images.to(device)
            edges = edges.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(images, edges)
            loss = criterion(outputs, labels)

            # Optional explicit weight penalty added on top of the criterion.
            if regularization_type == 'L1':
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                loss = loss + lambda_reg * l1_norm
            elif regularization_type == 'L2':
                l2_norm = sum(p.pow(2).sum() for p in model.parameters())
                loss = loss + lambda_reg * l2_norm

            loss.backward()
            optimizer.step()

            # Changing outputs (logits) to labels
            outputs_clear = outputs.max(1).indices

            # Accumulate plain Python numbers so the epoch metrics are floats
            # rather than device tensors.
            epoch_correct += (outputs_clear == labels).sum().item()
            epoch_samples += len(outputs)
            epoch_batches += 1

            # The reported training loss includes the penalty term.
            running_loss += loss.item()

        # Guard against an empty loader instead of dividing by zero.
        tAccuracy = epoch_correct / epoch_samples * 100 if epoch_samples else 0.0
        tLoss = running_loss / epoch_batches if epoch_batches else 0.0

        # Validation
        vAccuracy, vLoss, report, cm = evaluate_augmented_model(model, criterion, val_loader, device)

        learning_rate = get_lr(optimizer)
        if logger is not None:
            logger.add_text("REGULARIZATION_TYPE", regularization_type, global_step=epochCounter)
            logger.add_scalar("REGULARIZATION_LAMBDA", lambda_reg, global_step=epochCounter)
            logger.add_scalar("learning_rate", learning_rate, global_step=epochCounter)

            logger.add_scalar("Accuracy/train", tAccuracy, global_step=epochCounter)
            logger.add_scalar("Loss/train", tLoss, global_step=epochCounter)
            logger.add_scalar("Accuracy/validation", vAccuracy, global_step=epochCounter)
            logger.add_scalar("Loss/validation", vLoss, global_step=epochCounter)

            # The report and confusion matrix are only logged every 10th epoch.
            if epochCounter % 10 == 0:
                logger.add_text("Classification Report", report, global_step=epochCounter)

                # Build the display only when it is logged, so no figure is
                # created on the other epochs.
                cmDisplay = getConfusionMatrixDisplay(cm, classes)
                # add_figure expects a matplotlib Figure. If the helper returns
                # a display object exposing `figure_` (as scikit-learn's
                # ConfusionMatrixDisplay does after plotting), log that figure;
                # otherwise assume the helper already returned a Figure.
                # NOTE(review): confirm against getConfusionMatrixDisplay.
                cmFigure = getattr(cmDisplay, "figure_", cmDisplay)
                logger.add_figure("Confusion matrix", cmFigure, global_step=epochCounter)

        print(f'Epoch {epochCounter}: Training: accuracy: {tAccuracy:.3f}%, loss: {tLoss:.3f}; Validation: accuracy: {vAccuracy:.3f}%, loss: {vLoss:.3f}, lr: {learning_rate:.5f}')

        epochCounter += 1

        if lr_scheduler is not None:
            lr_scheduler.step()

    print('Finished Training')

#### Data Visualization

In [23]:
def visualize_cm(cm):
    """Build the confusion-matrix display for ``cm`` and show it via pyplot."""
    # The returned display object is not needed here; the helper presumably
    # draws onto a pyplot figure — confirm against getConfusionMatrixDisplay.
    getConfusionMatrixDisplay(cm)
    plt.show()
    
def visualize_all(model, criterion, loader):
    """Evaluate ``model`` on ``loader`` and print and plot the results.

    Prints the accuracy, the loss and the classification report returned by
    ``evaluate_augmented_model``, then shows the confusion matrix.
    """
    metrics = evaluate_augmented_model(model, criterion, loader, device)
    acc, loss, report, cm = metrics

    print(f"Accuracy: {acc}, loss: {loss}")
    print(report)

    visualize_cm(cm)

### Training Model

In [None]:
# Train for 100 epochs with an L2 penalty (lambda = 0.008). The logger
# argument is None, so nothing is written to TensorBoard for this run.
train_many(model, 100, None, my_lr_scheduler, "L2", 0.008)

In [None]:
# Report accuracy, loss, classification report and confusion matrix on the
# training split and on the held-out test split.
visualize_all(model, criterion, train_loader)
visualize_all(model, criterion, test_loader)

In [None]:
# Continue training the same model for another 100 epochs; the global
# epochCounter and the optimizer/scheduler objects carry over from the
# previous run.
train_many(model, 100, None, my_lr_scheduler, "L2", 0.008)

In [None]:
# visualize_all takes (model, criterion, loader); the criterion is required
# to compute the reported loss.
visualize_all(model, criterion, train_loader)
visualize_all(model, criterion, test_loader)

In [None]:
# Third block of 100 epochs on the same model, again without TensorBoard
# logging (logger argument is None).
train_many(model, 100, None, my_lr_scheduler, "L2", 0.008)

In [None]:
# visualize_all takes (model, criterion, loader); the criterion is required
# to compute the reported loss.
visualize_all(model, criterion, train_loader)
visualize_all(model, criterion, test_loader)