# IF5200 - Modeling Notebook
___
Group: 8<br>
Project: Automated Chest X-Ray Report Generator in Bahasa Indonesia with the Use of Deep Learning

## A. Print library versions

In [1]:
# Report the installed version of each third-party package the notebook uses,
# as (printed label, importable module name) pairs, in a fixed order.
for _version_label, _module_name in (
    ('tqdm version:', 'tqdm'),
    ('pandas version:', 'pandas'),
    ('pillow version:', 'PIL'),
    ('torch version:', 'torch'),
    ('torchvision version:', 'torchvision'),
):
    print(_version_label, __import__(_module_name).__version__)
# Do not leak the loop variables into the notebook's global namespace.
del _version_label, _module_name

tqdm version: 4.64.1
pandas version: 1.3.5
pillow version: 9.4.0
torch version: 1.13.1+cu117
torchvision version: 0.14.1+cu117


## B. Helper functions

### 1. TrainUtils

In [2]:
import torch
from torch import nn
from torch import optim
from tqdm import tqdm


class TrainUtils:
    """Small helper that runs training / evaluation epochs for a model.

    The loss function and optimizer are selected by name.

    Args:
        model: The model to train and evaluate.
        loss_fn: Name of the loss function. Only ``'CrossEntropyLoss'`` is
            supported.
        optimizer: Name of the optimizer: ``'SGD'`` or ``'Adam'``.
        learning_rate: Learning rate passed to the optimizer.
        device: Device that batches and the model are moved to. ``None``
            falls back to ``'cpu'``.

    Raises:
        ValueError: If ``loss_fn`` or ``optimizer`` is not supported.
    """

    def __init__(self, model,
                 loss_fn: str,
                 optimizer: str,
                 learning_rate: float = 1e-3,
                 device: str = None):

        super().__init__()

        # Supported loss functions and optimizers, keyed by name. Classes are
        # stored (not instances) so only the requested objects are built.
        supported_loss_fn = {
            'CrossEntropyLoss': nn.CrossEntropyLoss,
        }
        supported_optimizer = {
            'SGD': optim.SGD,
            'Adam': optim.Adam,
        }

        # Set model
        self.model = model

        # Set loss function
        if loss_fn not in supported_loss_fn:
            raise ValueError('Loss function is not supported!')
        self.loss_fn = supported_loss_fn[loss_fn]()

        # Set optimizer
        if optimizer not in supported_optimizer:
            raise ValueError('Optimizer is not supported!')
        self.optimizer = supported_optimizer[optimizer](
            model.parameters(), lr=learning_rate)

        # Set device. Report what is actually used: an explicit 'cpu'
        # (or torch.device('cpu')) must not be announced as a GPU.
        self.device = 'cpu' if device is None else device
        if str(self.device).split(':')[0] == 'cpu':
            print('Using CPU!\n')
        else:
            print('Using GPU!\n')

    def train(self, dataloader,
              print_log: bool = False):
        """Run one training epoch over ``dataloader``.

        Args:
            dataloader: Iterable of ``(X, y)`` batches. The model output for
                ``X`` and the target ``y`` are passed to the loss function.
            print_log: If true, print the recorded loss history afterwards.

        Returns:
            A list of ``[batch_index, loss]`` pairs recorded every 100
            batches (starting at batch 0). ``loss`` is a Python float.
        """
        model = self.model
        loss_fn = self.loss_fn
        optimizer = self.optimizer
        device = self.device

        # Moving the model and switching to train mode are idempotent, so do
        # them once per epoch instead of once per batch.
        model = model.to(device)
        model.train()

        loss_history = []

        for batch, (X, y) in enumerate(tqdm(dataloader)):
            # Send tensors to the device
            X, y = X.to(device), y.to(device)

            # Compute loss (error)
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Record a plain float. Storing the tensor would keep it on the
            # device and attached to its autograd graph.
            if batch % 100 == 0:
                loss_history.append([batch, loss.item()])

        # Print loss history
        if print_log:
            print('Loss over batches:')
            print(' Batch\tLoss')
            for batch_idx, batch_loss in loss_history:
                print(f' {batch_idx}\t{batch_loss:>7f}')

        return loss_history

    def test(self, dataloader, print_log=False):
        """Evaluate the model on ``dataloader`` without tracking gradients.

        Args:
            dataloader: Loader of ``(X, y)`` batches whose ``dataset`` supports
                ``len()``. Accuracy compares ``pred.argmax(1)`` with ``y``, so
                ``y`` is expected to hold class indices.
            print_log: If true, print accuracy and average loss.

        Returns:
            Accuracy as a fraction of ``len(dataloader.dataset)``.

        Raises:
            ZeroDivisionError: If the dataloader / dataset is empty.
        """
        model = self.model
        loss_fn = self.loss_fn
        device = self.device

        size = len(dataloader.dataset)
        num_batches = len(dataloader)

        # Move once and switch to eval mode (affects dropout / batch norm).
        model = model.to(device)
        model.eval()

        test_loss, correct = 0, 0

        with torch.no_grad():
            for X, y in tqdm(dataloader):
                # Send tensors to the device
                X, y = X.to(device), y.to(device)

                # Make prediction
                pred = model(X)

                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        # Average loss per batch, accuracy per sample.
        test_loss /= num_batches
        correct /= size

        # Print test accuracy and test loss
        if print_log:
            print(f'Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}')

        # Return test accuracy
        return correct

### 2. ModelUtils

In [5]:
from torch import nn
from torchvision.models import resnet18, resnet50, ResNet18_Weights, ResNet50_Weights
