In [None]:
#Import modules 
import torch
import torch.nn as nn
import time as t 

# Device configuration: prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by the model code in this notebook; tensors are
# placed wherever the caller puts them.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
#residual block 
class ResidualBlock(nn.Module):
    """Fully connected residual block computing ``act(FC2(FC1(x)) + shortcut(x))``.

    FC1 is BatchNorm1d -> Linear -> activation and FC2 is BatchNorm1d -> Linear.
    FC2 has no activation of its own; the activation is applied after the skip
    connection is added.

    Args:
        in_channels: Number of input features.
        out_channels: Number of output features.
        actFunc: Activation *class* (not an instance); it is instantiated twice.
        stride: Accepted for interface compatibility with convolutional ResNet
            blocks; unused by this fully connected block.
        downsample: Optional module applied to the input to form the shortcut.
            When falsy and ``in_channels != out_channels``, a Linear projection
            is created so the residual addition has matching shapes.
    """

    def __init__(self, in_channels, out_channels, actFunc=nn.ReLU, stride = 1, downsample = False):
        super(ResidualBlock, self).__init__()

        # Sequential layer 1: normalize, project to out_channels, activate.
        self.FC1 = nn.Sequential(
            nn.BatchNorm1d(in_channels),
            nn.Linear(in_channels, out_channels),
            actFunc(),
        )

        # Sequential layer 2: normalize and project; activation comes after the skip add.
        self.FC2 = nn.Sequential(
            nn.BatchNorm1d(out_channels),
            nn.Linear(out_channels, out_channels),
        )

        # An identity shortcut cannot be added to an output of a different width,
        # so build a projection when the caller did not supply one.
        if not downsample and in_channels != out_channels:
            downsample = nn.Linear(in_channels, out_channels)
        self.downsample = downsample
        self.actFunc1 = actFunc()
        self.out_channels = out_channels

    def forward(self, x):
        """Apply the block; assumes ``x`` is (batch, in_channels) — TODO confirm for other ranks."""
        residual = self.downsample(x) if self.downsample else x
        out = self.FC2(self.FC1(x))
        # Out-of-place add keeps autograd happy regardless of what FC2 returns.
        out = out + residual
        return self.actFunc1(out)

In [None]:
class ResNet(nn.Module):
    """ResNet-style classifier built from fully connected residual blocks.

    Each input sample is flattened, encoded to ``self.inplanes`` (64) features,
    passed through ``layers`` residual blocks and decoded to ``num_classes``
    scores. ``forward`` returns softmax probabilities. Losses are computed on
    the pre-softmax logits, because ``nn.CrossEntropyLoss`` applies its own
    log-softmax.

    Args:
        block: Residual block class, called as
            ``block(in_features, out_features, actFunc[, stride, downsample])``.
        layers: Number of residual blocks to stack (at least one is always built).
        img_input_dim: Side length of the square input; after flattening, each
            sample must have ``img_input_dim * img_input_dim`` elements.
        actFunc: Activation class (not an instance) used by the encoder and blocks.
        num_classes: Number of output classes.
    """

    def __init__(self, block, layers, img_input_dim = 64, actFunc=nn.ReLU, num_classes = 10):
        super(ResNet, self).__init__()
        self.inplanes = 64

        # Activation class handed to every residual block.
        self.actFunc = actFunc

        # Layers
        self.encoder = nn.Linear(img_input_dim * img_input_dim, self.inplanes)
        self.input_actFunc = actFunc()
        self.hid_layers = self._make_layer(block, self.inplanes, layers, stride=1)
        self.decoder = nn.Linear(self.inplanes, num_classes)

        # Softmax turns decoder logits into class probabilities (forward() output only).
        self.output_actFunc = nn.Softmax(dim=1)
        # Applied to raw logits: CrossEntropyLoss performs log-softmax internally.
        self.criterion = nn.CrossEntropyLoss()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack residual blocks mapping ``self.inplanes`` -> ``planes`` features.

        The first block receives ``stride`` and ``downsample=False``; no shortcut
        projection is built here. ``self.inplanes`` is updated to ``planes``.
        """
        # The former Conv2d/BatchNorm2d downsample was dead code (always
        # overwritten with False) and shape-incompatible with 1-D features.
        layers = [block(self.inplanes, planes, self.actFunc, stride, False)]
        self.inplanes = planes
        layers.extend(block(self.inplanes, planes, self.actFunc) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def _logits(self, x):
        """Return pre-softmax class scores of shape (batch, num_classes) for batch ``x``."""
        x = x.view(x.size(0), -1)
        x = self.input_actFunc(self.encoder(x))
        x = self.hid_layers(x)
        return self.decoder(x)

    def forward(self, x):
        """Return class probabilities (softmax over dim 1) for batch ``x``."""
        return self.output_actFunc(self._logits(x))

    def _to_model_device(self, batch):
        """Move an ``(images, labels)`` batch onto the device holding the model parameters."""
        images, labels = batch
        dev = next(self.parameters()).device
        return images.to(dev), labels.to(dev)

    def evaluate(self, val_loader):
        """Evaluate the model's performance on the validation set.

        Gradients are disabled here. The train/eval mode is left to the caller
        (``fit`` switches to eval mode first).

        Returns:
            dict with float ``'val_loss'`` and ``'val_acc'`` averaged over batches.
        """
        with torch.no_grad():
            outputs = [self.validation_step(batch) for batch in val_loader]
        return self.validation_epoch_end(outputs)

    def fit(self, epochs, lr, mo, train_loader, val_loader, opt_func=torch.optim.SGD, print_statement=True):
        """Train the model using gradient descent, validating after every epoch.

        Args:
            epochs: Number of passes over ``train_loader``.
            lr: Learning rate (second positional argument of ``opt_func``).
            mo: Third positional argument of ``opt_func`` (momentum for SGD).
            train_loader: Iterable of ``(images, labels)`` training batches.
            val_loader: Iterable of ``(images, labels)`` validation batches.
            opt_func: Optimizer class, called as ``opt_func(params, lr, mo)``.
            print_statement: Print per-epoch results when True.

        Returns:
            list of per-epoch dicts with ``val_loss``, ``val_acc`` and ``epoch_time``.
        """
        history = []
        optimizer = opt_func(self.parameters(), lr, mo)

        for epoch in range(epochs):
            t0 = t.time()
            # Training phase
            self.train()
            for batch in train_loader:
                # Clear gradients first so pre-existing grads never leak into a step.
                optimizer.zero_grad()
                loss = self.training_step(batch)
                loss.backward()
                optimizer.step()

            # Validation phase: eval mode (e.g. BatchNorm uses running stats), no autograd.
            self.eval()
            with torch.inference_mode():
                result = self.evaluate(val_loader)
            result['epoch_time'] = t.time() - t0

            if print_statement:
                self.epoch_end(epoch, result)
            history.append(result)
        if print_statement:
            print('-----------------------------------------------------')
        return history

    def training_step(self, batch):
        """Return the cross-entropy loss of one ``(images, labels)`` batch."""
        images, labels = self._to_model_device(batch)
        # Feed logits, not softmax output: CrossEntropyLoss already applies
        # log-softmax, and a second softmax would flatten the gradients.
        return self.criterion(self._logits(images), labels)

    def accuracy(self, outputs, labels):
        """Return, as a 0-d tensor, the fraction of rows whose argmax over dim 1 equals ``labels``."""
        preds = outputs.argmax(dim=1)
        return (preds == labels).float().mean()

    def validation_step(self, batch):
        """Return ``{'val_loss', 'val_acc'}`` tensors for one validation batch."""
        images, labels = self._to_model_device(batch)
        logits = self._logits(images)
        loss = self.criterion(logits, labels)
        # argmax of the logits equals argmax of the softmax probabilities.
        acc = self.accuracy(logits, labels)
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average per-batch losses and accuracies into python floats.

        This is an unweighted mean over batches, so a smaller final batch
        counts as much as a full one.
        """
        epoch_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        epoch_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print one epoch's validation metrics and wall-clock time."""
        print("Epoch [{}] val_loss: {:.4f}, val_acc: {:.4f}, time: {:.4f} s".format(epoch, result['val_loss'], result['val_acc'], result['epoch_time']))