In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import os

# these next packages come from within our file structure
import import_ipynb

from models.Deeplab.resnet import ResNet18_OS16, ResNet34_OS16, ResNet50_OS16, ResNet101_OS16, ResNet152_OS16, ResNet18_OS8, ResNet34_OS8
from models.Deeplab.aspp import ASPP, ASPP_Bottleneck

class DeepLabV3(nn.Module):
    """DeepLabV3 segmentation network: a ResNet backbone followed by an ASPP head.

    The backbone and head are fixed in ``__init__`` (currently ``ResNet18_OS8``
    and ``ASPP``). Model/checkpoint directory handling is intentionally not done
    here; it is taken care of by the wider trainer script.

    Args:
        num_classes: Number of segmentation classes predicted per pixel.
            Defaults to 20, the value that was previously hard-coded.
    """

    def __init__(self, num_classes=20):
        super().__init__()

        self.num_classes = num_classes

        # NOTE! specify the type of ResNet here.
        self.resnet = ResNet18_OS8()
        # NOTE! if you use ResNet50-152, set
        # self.aspp = ASPP_Bottleneck(num_classes=self.num_classes) instead.
        # The ASPP head used here returns two values (see forward()).
        self.aspp = ASPP(num_classes=self.num_classes)

    def forward(self, x):
        """Run the network and return per-pixel class scores at input resolution.

        Args:
            x: Input batch of shape (batch_size, 3, h, w).

        Returns:
            A tuple ``(output, softmax, bias_fork)``:
              * ``output``: raw class scores, shape (batch_size, num_classes, h, w).
              * ``softmax``: ``output`` normalised over the class dimension (dim=1).
              * ``bias_fork``: second value returned by the ASPP head, passed
                through untouched (its meaning is defined by the ASPP module).
        """
        height, width = x.shape[2], x.shape[3]

        # Backbone features. Per the backbone variants: (batch_size, 512, h/8, w/8)
        # for ResNet18_OS8 / ResNet34_OS8, (batch_size, 512, h/16, w/16) for
        # ResNet18_OS16 / ResNet34_OS16, and (batch_size, 4*512, h/16, w/16) for
        # ResNet50-152.
        feature_map = self.resnet(x)

        # Class scores at feature-map resolution plus the extra ASPP output.
        # NOTE(review): presumably (batch_size, num_classes, h/8, w/8) with the
        # OS8 backbone -- confirm against the ASPP implementation.
        output, bias_fork = self.aspp(feature_map)

        # Upsample back to the input size so the prediction lines up with the
        # label images (the dataloaders give labels the same size as the input).
        # F.upsample is deprecated; F.interpolate with align_corners=False is
        # what it resolved to by default, so the numerical result is unchanged.
        output = F.interpolate(
            output, size=(height, width), mode="bilinear", align_corners=False
        )

        # dim=1 so the softmax is taken across the classes.
        softmax = F.softmax(output, dim=1)

        return output, softmax, bias_fork


importing Jupyter notebook from resnet.ipynb
importing Jupyter notebook from aspp.ipynb
