# In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

class ConvNetPL(pl.LightningModule):
    """Transfer-learning image classifier wrapped as a LightningModule.

    A torchvision architecture is instantiated by name, its weights are
    loaded from a local state-dict file, and the last entry of its
    ``classifier`` container is replaced by a small custom head that emits
    log-probabilities over ``num_classes``.

    The wrapped network is accessed through ``self.model.classifier`` and
    ``self.model.features``, so only architectures exposing both containers
    are usable with the transfer-learning helpers.

    NOTE(review): ``training_epoch_end`` / ``validation_epoch_end`` /
    ``test_epoch_end`` are the PyTorch Lightning 1.x hook names -- confirm
    the installed Lightning version still supports them.
    """

    def __init__(self, pretrained_model_name, pretrained_model_path, num_classes, base_lr, batch_size, train_path, test_path):
        """Build the network and store the training configuration.

        Args:
            pretrained_model_name: Name of a constructor in
                ``torchvision.models`` (looked up with ``getattr``).
            pretrained_model_path: Path to the state dict loaded into the
                freshly constructed architecture.
            num_classes: Output size of the replacement classification head.
            base_lr: Base learning rate; per-layer rates are multiples of it.
            batch_size: Batch size used by all three dataloaders.
            train_path: Root folder handed to ``ImageFolder`` for training
                and validation data.
            test_path: Root folder handed to ``ImageFolder`` for test data.
        """
        super().__init__()

        self.pretrained_model_name = pretrained_model_name
        self.pretrained_model_path = pretrained_model_path
        self.num_classes = num_classes
        self.lr = base_lr
        # The dataloaders read self.batch_size, so it must be stored here.
        self.batch_size = batch_size
        self.train_path = train_path
        self.test_path = test_path

        # Per-epoch means appended by the *_epoch_end hooks.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
        }

        # Transfer-learning parameters; -1 means "nothing selected".
        self.classifier_n = -1
        self.features_n = -1

        # Build the architecture without weights. No manual .cuda(): the
        # Trainer moves the whole module to the target device, and moving
        # only the backbone here would leave the new head on another device.
        self.model = getattr(models, self.pretrained_model_name)()

        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        state_dict = torch.load(self.pretrained_model_path, map_location='cpu')
        self.model.load_state_dict(state_dict)

        # Input width and position of the last classifier entry to replace.
        head_idx = len(self.model.classifier) - 1
        self.in_feat = self.model.classifier[head_idx].in_features

        # No ReLU/Dropout after the final Linear: applying them to the class
        # scores would clamp them to >= 0 and randomly zero them in training.
        # Linear layers stay at positions 0 and 3, so parameter names are
        # unchanged.
        self.model.classifier[head_idx] = nn.Sequential(
            nn.Linear(self.in_feat, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, self.num_classes),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Return the wrapped model's output (log-probabilities) for ``x``."""
        return self.model(x)

    def freeze_all_layers(self):
        """Set ``requires_grad = False`` on every parameter of the module."""
        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def _unfreeze_last_n(module, n, freeze_rest=False):
        """Make the trailing parameters of ``module`` trainable.

        ``n`` counts layers; each layer is assumed to contribute two
        parameter tensors (weight and bias), so ``2 * n`` tensors are
        selected. ``n == -1`` is a no-op, ``n == 0`` selects every tensor.
        When ``freeze_rest`` is true, tensors outside the selection are
        explicitly frozen; otherwise they are left untouched.

        Returns:
            The number of parameter tensors set to ``requires_grad = True``.
        """
        if n == -1:
            return 0

        params = list(module.parameters())
        wanted = n * 2

        if wanted > len(params):
            print(f"Warning: There are only {len(params)} parameter tensors in the model. "
                  f"Cannot unfreeze {n} layers ({wanted} tensors).")
            return 0

        first_trainable = 0 if n == 0 else len(params) - wanted
        unfrozen = 0
        for idx, param in enumerate(params):
            if idx >= first_trainable:
                param.requires_grad = True
                unfrozen += 1
            elif freeze_rest:
                param.requires_grad = False
        return unfrozen

    def unfreeze_last_n_fc_layers(self, n):
        """Unfreeze the last ``n`` layers of ``self.model.classifier``.

        ``n == -1`` unfreezes nothing, ``n == 0`` unfreezes all of them.
        Earlier classifier parameters are explicitly frozen.

        Returns:
            The number of parameter tensors made trainable (0 for ``-1``).
        """
        return self._unfreeze_last_n(self.model.classifier, n, freeze_rest=True)

    def unfreeze_last_n_conv_layers(self, n):
        """Unfreeze the last ``n`` layers of ``self.model.features``.

        ``n == -1`` unfreezes nothing, ``n == 0`` unfreezes all of them.
        Earlier feature parameters keep their current ``requires_grad``.

        Returns:
            The number of parameter tensors made trainable (0 for ``-1``).
        """
        return self._unfreeze_last_n(self.model.features, n, freeze_rest=False)

    def set_transfer_learning_params(self, unfreeze_n_fc, unfreeze_n_conv):
        """Freeze everything, then unfreeze the requested trailing layers.

        The two counts are stored so that ``get_optimizer_params_list`` can
        build matching per-layer parameter groups.
        """
        self.classifier_n = unfreeze_n_fc
        self.features_n = unfreeze_n_conv
        self.freeze_all_layers()
        self.unfreeze_last_n_fc_layers(unfreeze_n_fc)
        self.unfreeze_last_n_conv_layers(unfreeze_n_conv)

    def _layer_param_groups(self, container, n, factor, first_exponent):
        """Build one optimizer group per selected child of ``container``.

        The child index is the first component of each parameter name.
        Indices are sorted so the learning-rate schedule is deterministic:
        the k-th selected child gets ``lr * factor ** (first_exponent + k)``.
        """
        names = [name for name, _ in container.named_parameters()]
        if n != 0:
            # Last n layers, two tensors (weight, bias) per layer.
            names = names[-n * 2:]
        indices = sorted({int(name.split('.')[0]) for name in names})
        return [
            {
                'params': list(container[index_val].parameters()),
                'lr': self.lr * (factor ** (first_exponent + position)),
            }
            for position, index_val in enumerate(indices)
        ]

    def get_optimizer_params_list(self):
        """Return what the optimizer should train.

        If neither count was selected (both ``-1``) all module parameters
        are returned. Otherwise a list of parameter-group dicts is returned,
        each with its own ``lr`` derived from ``self.lr``.
        """
        if self.classifier_n == self.features_n == -1:
            return self.parameters()

        # Learning-rate multipliers between successive groups.
        f_fc = 2
        f_conv = 3

        params_list = []
        if self.classifier_n != -1:
            params_list.extend(
                self._layer_param_groups(self.model.classifier, self.classifier_n, f_fc, 0))
        if self.features_n != -1:
            params_list.extend(
                self._layer_param_groups(self.model.features, self.features_n, f_conv, 1))
        return params_list

    def configure_optimizers(self):
        """Create an Adam optimizer over the selected parameters."""
        params_list = self.get_optimizer_params_list()
        return torch.optim.Adam(params_list, lr=self.lr)

    def _shared_step(self, batch, stage):
        """Run one batch, log ``{stage}_loss`` / ``{stage}_acc``, return both.

        Returns a dict so that ``training_step`` hands Lightning a ``'loss'``
        entry, which it needs for the backward pass.
        """
        x, y = batch
        logps = self(x)
        loss = F.nll_loss(logps, y)
        # argmax over log-probabilities equals argmax over probabilities.
        y_pred = torch.argmax(logps, dim=1)
        acc = (y_pred == y).sum().item() / len(y)
        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(f'{stage}_acc', acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {'loss': loss, 'acc': acc}

    def _record_epoch(self, outputs, stage):
        """Append the mean loss and accuracy of ``outputs`` to ``history``.

        Values are stored as plain floats so the history does not hold
        tensors. An empty ``outputs`` is skipped.
        """
        if not outputs:
            return
        num_items = len(outputs)
        mean_loss = sum(float(out['loss']) for out in outputs) / num_items
        mean_acc = sum(float(out['acc']) for out in outputs) / num_items
        self.history[f'{stage}_loss'].append(mean_loss)
        self.history[f'{stage}_acc'].append(mean_acc)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy for one batch."""
        return self._shared_step(batch, 'train')

    def training_epoch_end(self, outputs):
        """Record the epoch's mean training loss and accuracy."""
        self._record_epoch(outputs, 'train')

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and accuracy for one batch."""
        return self._shared_step(batch, 'val')

    def validation_epoch_end(self, outputs):
        """Record the epoch's mean validation loss and accuracy."""
        self._record_epoch(outputs, 'val')

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss and accuracy for one batch."""
        return self._shared_step(batch, 'test')

    def test_epoch_end(self, outputs):
        """Placeholder: no end-of-test aggregation is performed."""
        pass

    def setup(self, stage=None):
        """Create the train, validation and test datasets.

        Reads the module-level names ``resizing_factor``, ``mean`` and
        ``std``, which are not defined in this class and must exist before
        this method runs. 20% of the training folder is split off for
        validation.

        NOTE(review): the validation subset comes from the same
        ``ImageFolder`` as the training subset, so it also goes through the
        augmenting ``train_transform`` -- confirm this is intended.
        """
        train_transform = transforms.Compose([
            transforms.Resize(resizing_factor),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomRotation(15),
            transforms.RandomAffine(degrees=10, translate=(0.2, 0.2), shear=10),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        test_transform = transforms.Compose([
            transforms.Resize(resizing_factor),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        train = torchvision.datasets.ImageFolder(self.train_path, transform=train_transform)
        total_items = len(train)
        val_size = int(total_items * 0.2)
        train_size = total_items - val_size
        self.train_dataset, self.val_dataset = random_split(train, [train_size, val_size])
        self.test_dataset = torchvision.datasets.ImageFolder(self.test_path, transform=test_transform)

    def train_dataloader(self):
        """Return a shuffled loader over the training subset."""
        return DataLoader(self.train_dataset, self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation subset."""
        return DataLoader(self.val_dataset, self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled loader over the test dataset."""
        return DataLoader(self.test_dataset, self.batch_size, shuffle=False)


# Instantiate the Lightning Trainer
trainer = pl.Trainer(max_epochs=epochs)

# Instantiate the model with every argument ConvNetPL.__init__ requires;
# passing only the first three raises TypeError.
# NOTE(review): base_lr, batch_size, train_path and test_path are assumed to
# be defined in an earlier cell, like epochs and pretrained_model_name --
# confirm the variable names.
model = ConvNetPL(
    pretrained_model_name,
    pretrained_model_path,
    num_classes,
    base_lr,
    batch_size,
    train_path,
    test_path,
)

# Train the model
trainer.fit(model)
