### Import Dependencies

In [1]:
import random

import numpy as np
import torch

# Seed every RNG the notebook depends on. `random.seed` alone does not affect
# torch's generator (used e.g. by DataLoader shuffling and parameter
# initialisation) or NumPy's, so the original cell did not make runs
# reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.transforms import v2

### Constants

In [3]:
# Root directory of the knee-osteoarthritis dataset, relative to the notebook.
DATASET_PATH: str = './knee-osteoarthritis'

In [4]:
# One sub-directory per dataset split, all located under DATASET_PATH.
TRAIN_PATH = '{}/train'.format(DATASET_PATH)
VAL_PATH = '{}/val'.format(DATASET_PATH)
TEST_PATH = '{}/test'.format(DATASET_PATH)
AUTO_TEST_PATH = '{}/auto_test'.format(DATASET_PATH)

### Building the Datasets

In [5]:
from src.dataset.augmented_dataset import get_KneeOsteoarthritis_Edges

# Build the training and validation datasets (the test splits are not loaded
# here). The training loop unpacks each batch as (images, edges, labels).
train_dataset, val_dataset = (
    get_KneeOsteoarthritis_Edges(split_path)
    for split_path in (TRAIN_PATH, VAL_PATH)
)

In [6]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img):
    """Display an image tensor with matplotlib.

    The tensor is un-normalised with ``img / 2 + 0.5``, which assumes it was
    normalised with mean 0.5 / std 0.5. NOTE(review): confirm this against the
    dataset's transforms.

    Accepts channels-first ``(C, H, W)`` tensors as well as plain ``(H, W)``
    tensors, on any device and whether or not they require grad.
    Single-channel images are shown with a gray colormap.
    """
    # detach() + cpu() so that .numpy() also works for GPU / autograd tensors
    img = img.detach().cpu() / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if npimg.ndim == 3:
        # channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib
        npimg = np.transpose(npimg, (1, 2, 0))
        if npimg.shape[-1] == 1:
            # drop the singleton channel so the image is treated as scalar
            # data and cmap='gray' is applied
            npimg = npimg[..., 0]
    plt.imshow(npimg, cmap='gray')
    plt.show()

In [None]:
# Peek at one training sample and print the shapes of its first two elements.
# NOTE(review): the training loop unpacks samples as (images, edges, labels),
# so element 1 is presumably the edge map rather than an augmented image.
row = train_dataset[1]
normal_ex, augmented_ex = row[0], row[1]
print(*(example.shape for example in (normal_ex, augmented_ex)))

### Configuring the Data Loaders

In [9]:
from torch.utils.data import DataLoader
BATCH_SIZE = 128

# Only the training data needs shuffling. Shuffling the validation set changes
# nothing about accuracy, but it consumes draws from torch's RNG, which shifts
# the training shuffle order. It also makes per-batch loss averages vary
# between runs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Setting up Model Training

#### Setting Up the Loss Function and Optimizer

In [None]:
import torch.optim as optim
from src.other import getClassesFrequency
# NOTE(review): `model` and `device` are not defined anywhere in this notebook.
# The cell numbering jumps, so they presumably came from a cell that was
# removed. Define them before running this cell.

class_weights = getClassesFrequency(train_dataset)
# NOTE(review): the helper's name suggests these are raw class *frequencies*.
# CrossEntropyLoss up-weights classes with larger weights, so compensating for
# class imbalance usually calls for inverse frequencies. Confirm what
# getClassesFrequency returns. The dict's value order must also match the
# class-index order of the dataset labels.
# torch.tensor (rather than the legacy torch.Tensor constructor) lets us fix
# the dtype and create the tensor directly on the target device.
weights_tensor = torch.tensor(
    list(class_weights.values()), dtype=torch.float32, device=device
)
print(weights_tensor, weights_tensor.dtype)

criterion = nn.CrossEntropyLoss(weight=weights_tensor)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [14]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)

#### Setting Up Logger

In [15]:
# Experiment/run name; the TensorBoard writer logs to logs/<EXP_NAME>.
EXP_NAME: str = 'only-resnet18/1'

In [16]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer for this experiment's run directory.
logger = SummaryWriter(log_dir="logs/{}".format(EXP_NAME))

In [17]:
# Global epoch index shared across train_many() calls; used as the logging step.
epochCounter: int = 0

#### Training Loop

In [18]:
from src.validation import validate

def _regularization_penalty(model, regularization_type, lambda_reg):
    """Return ``lambda_reg`` times the L1 or squared-L2 norm of all parameters.

    Returns None for any ``regularization_type`` other than ``'L1'`` or
    ``'L2'``, meaning no penalty is applied.
    """
    if regularization_type == 'L1':
        norm = sum(p.abs().sum() for p in model.parameters())
    elif regularization_type == 'L2':
        # squared L2 norm: sum of squared parameter values
        norm = sum(p.pow(2).sum() for p in model.parameters())
    else:
        return None
    return lambda_reg * norm


def train_many(model, epochs_nr, regularization_type="L2", lambda_reg=0.01):
    """Train ``model`` for ``epochs_nr`` epochs, validating and logging after each.

    Uses the notebook globals ``train_loader``, ``val_loader``, ``optimizer``,
    ``criterion``, ``device``, ``logger``, ``validate`` and ``get_lr``.
    ``epochCounter`` is global so that repeated calls continue the logging
    step count.

    Args:
        model: module called as ``model(images, edges)`` that returns class
            logits.
        epochs_nr: number of epochs to run.
        regularization_type: ``'L1'`` or ``'L2'`` to add a weight penalty to
            the optimised loss; any other value disables the penalty.
        lambda_reg: weight of the penalty term.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    global epochCounter

    for _ in range(epochs_nr):  # loop over the dataset multiple times
        # validate() may switch the model to eval mode. Re-enable training
        # behaviour (dropout, batch-norm statistics) at the start of each epoch.
        model.train()

        epoch_correct = 0
        epoch_samples = 0
        epoch_batches = 0
        running_data_loss = 0.0
        running_total_loss = 0.0

        for images, edges, labels in train_loader:
            images = images.to(device)
            edges = edges.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(images, edges)
            data_loss = criterion(outputs, labels)
            penalty = _regularization_penalty(model, regularization_type, lambda_reg)
            loss = data_loss if penalty is None else data_loss + penalty

            loss.backward()
            optimizer.step()

            # logits -> predicted class indices
            predictions = outputs.argmax(dim=1)
            epoch_correct += (predictions == labels).sum().item()
            epoch_samples += len(outputs)
            epoch_batches += 1

            running_data_loss += data_loss.item()
            running_total_loss += loss.item()

        if epoch_batches == 0:
            raise ValueError("train_loader produced no batches; cannot compute training metrics")

        tAccuracy = epoch_correct / epoch_samples * 100
        # Criterion-only loss, directly comparable with the validation loss,
        # which validate() computes from the same criterion without the penalty.
        tLoss = running_data_loss / epoch_batches
        # The penalised objective that is actually optimised.
        tTotalLoss = running_total_loss / epoch_batches

        # Validation
        vAccuracy, vLoss = validate(model, val_loader, criterion, device)

        logger.add_text("REGULARIZATION_TYPE", regularization_type, global_step=epochCounter)
        logger.add_scalar("REGULARIZATION_LAMBDA", lambda_reg, global_step=epochCounter)
        logger.add_scalar("learning_rate", get_lr(optimizer), global_step=epochCounter)

        logger.add_scalar("Accuracy/train", tAccuracy, global_step=epochCounter)
        logger.add_scalar("Loss/train", tLoss, global_step=epochCounter)
        logger.add_scalar("Loss/train_total", tTotalLoss, global_step=epochCounter)
        logger.add_scalar("Accuracy/validation", vAccuracy, global_step=epochCounter)
        logger.add_scalar("Loss/validation", vLoss, global_step=epochCounter)

        print(f'Epoch {epochCounter}: Training: accuracy: {tAccuracy:.3f}%, loss: {tLoss:.3f}; Validation: accuracy: {vAccuracy:.3f}%, loss: {vLoss:.3f}')

        epochCounter += 1

    print('Finished Training')

### Training Model

In [None]:
# Run 15 epochs with an L2 weight penalty scaled by 1e-3.
train_many(model, epochs_nr=15, regularization_type='L2', lambda_reg=0.001)