In [1]:
### Imports ###
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
!pip install SSIM-PIL
from SSIM_PIL import compare_ssim
from PIL import Image
from math import log10, sqrt

### Mount Drive ###
from google.colab import drive
drive.mount('/content/drive')

### Same Generator Settings from .pt file ###
# Must match the architecture the checkpoint was trained with.
num_blocks = 10       # how many times Generator.forward applies its residual block
transpose2D = False   # selects the decoder upsampling branch in Generator.forward

### Input / output locations ###
filePath = "/content/drive/MyDrive/"             # folder holding the MRI scans (outputs are written here too)
PATH = "/content/drive/MyDrive/" + "model.pt"    # G_MC checkpoint file
fileNames = ["fileName1", "fileName2"]           # scan base names WITHOUT ".png"; scans should be 256x256 PNGs

### Device to produce colorizations ###
use_cuda_device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(use_cuda_device)
del use_cuda_device
print(device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
cuda


In [2]:
### Generator MC to do forward pass ###
class ResidualBlock(nn.Module):
    """Residual unit on 256-channel feature maps: ``x + IN(conv(pad(ReLU(IN(conv(pad(x)))))))``.

    Each 3x3 convolution is preceded by 1-pixel reflection padding, so the
    spatial size is preserved and the skip connection is a plain addition.

    NOTE: both convolutions reuse the SAME ``convRes`` module, i.e. they share
    weights. The loaded model only has a single ``block.convRes`` entry, so this
    layout must be kept for ``load_state_dict`` to keep working.
    """

    def __init__(self):
        super().__init__()
        self.convRes = nn.Conv2d(256, 256, 3)
        self.reflect1 = nn.ReflectionPad2d(1)
        self.norm256 = nn.InstanceNorm2d(256)
        self.ReLu = nn.ReLU(inplace=True)

    def block(self, x):
        """Residual branch; the skip connection is added in ``forward``."""
        x = self.reflect1(x)
        x = self.ReLu(self.norm256(self.convRes(x)))
        x = self.reflect1(x)
        x = self.norm256(self.convRes(x))
        return x

    def forward(self, x):
        return x + self.block(x)


class Generator(nn.Module):
    """Encoder / residual / decoder generator (G_MC) used to produce the colorizations.

    Pipeline for an input of shape (N, 3, H, W):
      reflect-pad 3 -> 7x7 conv 3->64 -> stride-2 3x3 convs 64->128->256 (H/4, W/4)
      -> the shared ``ResidualBlock`` applied ``num_blocks`` times
      -> two 2x upsampling stages 256->128->64
      -> reflect-pad 3 -> 7x7 conv 64->3 -> tanh, so outputs lie in [-1, 1].
    The output has the same spatial size as the input when H and W are divisible by 4.

    NOTE: ``self.block`` is ONE ``ResidualBlock`` reused for every residual stage
    (shared weights). This mirrors the checkpoint layout (a single ``block``
    entry); do not turn it into a ModuleList without converting the weights.

    Args:
        num_blocks: number of times the shared residual block is applied. If
            omitted, uses the notebook-level ``num_blocks`` setting (10 if unset).
        transpose2D: if True, the decoder uses learned ``upsample1``/``upsample2``
            layers; otherwise nearest-neighbour upsampling followed by
            ``conv4``/``conv5``. If omitted, uses the notebook-level
            ``transpose2D`` setting (False if unset).
    """

    def __init__(self, num_blocks=None, transpose2D=None):
        super().__init__()

        # Resolve the settings once at construction time rather than reading the
        # notebook globals on every forward pass; the fallbacks keep `Generator()`
        # consistent with the config cell.
        if num_blocks is None:
            num_blocks = globals().get("num_blocks", 10)
        if transpose2D is None:
            transpose2D = globals().get("transpose2D", False)
        self.num_blocks = int(num_blocks)
        self.transpose2D = bool(transpose2D)

        self.conv1 = nn.Conv2d(3, 64, 7, stride=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1)
        self.block = ResidualBlock()
        self.conv4 = nn.Conv2d(256, 128, 3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(128, 64, 3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(64, 3, 7, stride=1, padding=0)

        self.reflect1 = nn.ReflectionPad2d(1)
        self.reflect3 = nn.ReflectionPad2d(3)
        self.ReLu = nn.ReLU(inplace=True)
        self.tanH = nn.Tanh()

        self.upsample = nn.Upsample(scale_factor=2)
        self.norm256 = nn.InstanceNorm2d(256)
        self.norm128 = nn.InstanceNorm2d(128)
        self.norm64 = nn.InstanceNorm2d(64)

        if self.transpose2D:
            # These layers were used by forward() but never defined, so the
            # transpose2D path always crashed. They are registered only in this
            # mode, so the default checkpoint's keys are unaffected.
            # NOTE(review): kernel 3 / stride 2 / padding 1 / output_padding 1
            # (exact 2x upsampling) is an assumption -- confirm against the
            # training code; load_state_dict will report any shape mismatch.
            self.upsample1 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1)
            self.upsample2 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        x = self.reflect3(x)
        x = self.ReLu(self.norm64(self.conv1(x)))   # 1: 3 -> 64
        x = self.ReLu(self.norm128(self.conv2(x)))  # 2: 64 -> 128, /2 spatial
        x = self.ReLu(self.norm256(self.conv3(x)))  # 3: 128 -> 256, /2 spatial
        # The same residual block (shared weights) is applied repeatedly.
        for _ in range(self.num_blocks):
            x = self.block(x)
        if self.transpose2D:
            x = self.ReLu(self.norm128(self.upsample1(x)))  # 4: 256 -> 128, x2 spatial
            x = self.ReLu(self.norm64(self.upsample2(x)))   # 5: 128 -> 64, x2 spatial
        else:
            x = self.ReLu(self.norm128(self.conv4(self.upsample(x))))  # 4: 256 -> 128, x2 spatial
            x = self.ReLu(self.norm64(self.conv5(self.upsample(x))))   # 5: 128 -> 64, x2 spatial
        x = self.reflect3(x)
        return self.tanH(self.conv6(x))

### Load StateDict ###
# Build the generator directly on the target device, then copy in the trained weights.
Generator_MC = Generator().to(device)
# The checkpoint is deserialized onto the CPU first (so it loads on CPU-only
# runtimes as well); load_state_dict then copies each tensor into the parameters
# that already live on `device`. Its 'GMC_state_dict' entry holds G_MC's weights.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
Generator_MC.load_state_dict(
    torch.load(PATH, map_location='cpu')['GMC_state_dict']
)
Generator_MC.eval()

Generator(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (block): ResidualBlock(
    (convRes): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (reflect1): ReflectionPad2d((1, 1, 1, 1))
    (norm256): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (ReLu): ReLU(inplace=True)
  )
  (conv4): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(64, 3, kernel_size=(7, 7), stride=(1, 1))
  (reflect1): ReflectionPad2d((1, 1, 1, 1))
  (reflect3): ReflectionPad2d((3, 3, 3, 3))
  (ReLu): ReLU(inplace=True)
  (tanH): Tanh()
  (upsample): Upsample(scale_factor=2.0, mode=nearest)
  (norm256): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats

In [3]:
# Colorize every scan in `fileNames` and save it next to the input as "<name>-color.png".
for i in fileNames:
    prefix = filePath
    suffix = ".png"
    path = prefix + str(i) + suffix

    # `with` closes the file handle promptly. convert("RGB") guarantees the
    # (H, W, 3) array the 3-channel generator needs; grayscale, palette or RGBA
    # PNGs would otherwise give (H, W) or (H, W, 4) arrays and fail in conv1.
    # For files that are already RGB this is a no-op.
    with Image.open(path) as img:
        testImage = np.asarray(img.convert("RGB"))

    # Map uint8 [0, 255] -> [-1, 1], matching the generator's tanh output range.
    testImageScaled = (testImage / 127.5) - 1
    # NOTE: `.T` reverses ALL axes, (H, W, 3) -> (3, W, H), so height and width
    # are swapped too; the inverse `.T` below undoes it. Presumably training used
    # the same convention -- keep both in sync (TODO confirm vs. training code).
    testTensor = torch.tensor(testImageScaled.T, dtype=torch.float32).unsqueeze(0).to(device)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        outGenerated = Generator_MC(testTensor)[0]
    outGenerated = torch.clip(outGenerated, -1, 1)
    # Back to [0, 255]. Round instead of truncating so values such as 254.9999
    # map to 255; after the clip the result always fits in uint8.
    outGenerated = np.rint((outGenerated.cpu().numpy().T + 1) * 127.5).astype(np.uint8)

    saveName = str(i) + "-color"
    im = Image.fromarray(outGenerated)
    im.save(prefix + saveName + suffix)
