In [7]:
import torch 
import os 
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torch.nn.functional as F


In [9]:
# Notebook-wide configuration

# Number of output classes handed to the segmentation model.
N_CLASSES = 2

# Location of the pretrained classifier checkpoint (relative to the working directory).
pretrained_model = 'ckpt/pretrained_classifier_stroma+tumor.pth'

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
# Load the pretrained classifier. The loaded object is used as a module below
# (`.parameters()`, `.children()`), i.e. the checkpoint holds a whole pickled
# model rather than a state dict.
# SECURITY NOTE: unpickling can execute arbitrary code -- only load checkpoints
# that come from a trusted source.
# `map_location=device` remaps the stored tensors onto the device selected for
# this session, so a checkpoint saved on a GPU machine also loads on a CPU-only one.
pretrained_resnet50 = torch.load(pretrained_model, map_location=device)

# Freeze weights: no parameter of the pretrained network receives gradients.
# NOTE: this does not stop BatchNorm layers from updating their running
# statistics while the module is in training mode; call `.eval()` on the encoder
# if those statistics must stay fixed as well.
for param in pretrained_resnet50.parameters():
    param.requires_grad = False

# Remove the last layer: keep every direct child module except the final one.
encoder = torch.nn.Sequential(*list(pretrained_resnet50.children())[:-1])
encoder

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [8]:
class UNet(torch.nn.Module):
    """U-Net style segmentation network built on a ResNet-50-like encoder.

    The encoder is indexed positionally and is expected to follow the child
    order of a torchvision ResNet with its final layer removed:

        0: conv1, 1: bn1, 2: relu   -> stem features,   64 channels
        3: maxpool, 4: layer1       -> skip feature,    256 channels
        5: layer2                   -> skip feature,    512 channels
        6: layer3                   -> skip feature,   1024 channels
        7: layer4                   -> bottleneck,     2048 channels

    Any module after index 7 (e.g. a global average pool) is ignored, because a
    segmentation decoder needs the spatial feature map.

    Args:
        encoder: indexable container of modules (e.g. ``torch.nn.Sequential``)
            laid out as described above.
        num_classes: number of output channels (one logit map per class).

    Raises:
        ValueError: if ``encoder`` has fewer than 8 child modules.
    """

    def __init__(self, encoder, num_classes):
        super().__init__()
        if len(encoder) < 8:
            raise ValueError(
                "encoder must expose at least 8 child modules "
                "(conv1, bn1, relu, maxpool, layer1..layer4); got %d" % len(encoder)
            )
        self.encoder = encoder

        # Each decoder stage doubles the spatial resolution and halves the channels.
        self.decoder1 = torch.nn.ConvTranspose2d(2048, 1024, kernel_size=2, stride=2)
        self.decoder2 = torch.nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder3 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)

        # Fusion convolutions: in_channels = upsampled channels + skip channels.
        self.conv1 = torch.nn.Conv2d(1024 + 1024, 1024, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(512 + 512, 512, kernel_size=3, padding=1)
        self.conv3 = torch.nn.Conv2d(256 + 256, 256, kernel_size=3, padding=1)
        # The stem skip only has 64 channels, hence 128 + 64 inputs here.
        self.conv4 = torch.nn.Conv2d(128 + 64, 128, kernel_size=3, padding=1)

        self.output_conv = torch.nn.Conv2d(128, num_classes, kernel_size=1)

    @staticmethod
    def _fuse(upsampled, skip, conv):
        """Concatenate a skip feature with an upsampled map and apply conv + ReLU."""
        # A stride-2 transposed conv exactly doubles the size, which differs from
        # the skip by one pixel whenever the encoder rounded an odd size; resize
        # so the concatenation is always valid.
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = F.interpolate(
                upsampled, size=skip.shape[-2:], mode='bilinear', align_corners=False
            )
        return F.relu(conv(torch.cat([skip, upsampled], dim=1)))

    def forward(self, x):
        """Return per-pixel class logits with the same height/width as ``x``.

        Args:
            x: input batch of shape (N, C, H, W), where C is whatever the first
                encoder module accepts.

        Returns:
            Tensor of shape (N, num_classes, H, W) holding raw logits.
        """
        input_size = x.shape[-2:]

        # Encoder: collect one feature map per resolution level.
        stem = self.encoder[2](self.encoder[1](self.encoder[0](x)))
        feat1 = self.encoder[4](self.encoder[3](stem))
        feat2 = self.encoder[5](feat1)
        feat3 = self.encoder[6](feat2)
        bottleneck = self.encoder[7](feat3)

        # Decoder: upsample, merge with the matching encoder feature, refine.
        x = self._fuse(self.decoder1(bottleneck), feat3, self.conv1)
        x = self._fuse(self.decoder2(x), feat2, self.conv2)
        x = self._fuse(self.decoder3(x), feat1, self.conv3)
        x = self._fuse(self.decoder4(x), stem, self.conv4)

        x = self.output_conv(x)

        # The stem features sit at half the input resolution; bring the logits
        # back to the input size so they align with per-pixel targets.
        return F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)

In [11]:
# Wrap the frozen encoder in the segmentation network, then move the whole
# model onto the selected device.
unet_model = UNet(encoder, num_classes=N_CLASSES)
unet_model = unet_model.to(device)
unet_model

UNet(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 256