In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [38]:
from google.colab import drive
drive.mount('/content/drive')
import sys
sys.path.append('/content/drive/MyDrive/Colab Notebooks/')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [39]:
noisy_imgs1, noisy_imgs2 = torch.load('/content/drive/MyDrive/Colab Notebooks/train_data.pkl') # 50000 x 3 x 32 x 32
noisy_val, clean_val = torch.load('/content/drive/MyDrive/Colab Notebooks/val_data.pkl')

In [40]:
def psnr(denoised, ground_truth):
  mse = torch.mean((denoised - ground_truth) ** 2)
  return -10*torch.log10(mse + 10 ** -8)

In [41]:
def validate(model, noise_img, ground_truth):
  psnr_tot = 0
  if torch.cuda.is_available():
      noise_img, ground_truth = noise_img.to("cuda"), ground_truth.to("cuda")
  for i in range(noise_img.size(0)):
    denoised = model(noise_img[i].view(1, 3, 32, 32))
    print(noise_img[i])
    psnr_val = psnr(denoised, ground_truth[i]).item()
    print(psnr_val)
    psnr_tot += psnr_val
  psnr_tot /= noise_img.size(0)
  return psnr_tot

In [5]:
def train_model(model, train_input, train_target, criterion, optimizer, mini_batch_size=4, epochs=500, normalize=False):
    if torch.cuda.is_available():
      model.to("cuda")
      train_input, train_target = train_input.to("cuda"), train_target.to("cuda")
    if normalize:
      mu, std = train_input.mean(), train_input.std()
      train_input.sub_(mu).div_(std)

    for e in range(epochs):
      print(e)
      #TODO ADD PRINT OF CURRENT LOSS TO MAKE SURE THERE IS NO NAN
      for b in range(0, train_input.size(0), mini_batch_size):
          output = model(train_input.narrow(0, b, mini_batch_size))
          loss = criterion(output, train_target.narrow(0, b, mini_batch_size))
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

In [6]:
class SuperbModel(torch.nn.Module):
    def __init__(self, transposed_conv=False):
        super().__init__()
        self.enc_conv0 = nn.Conv2d(3, 48, (3, 3), padding='same')
        self.enc_conv1 = nn.Conv2d(48, 48, (3, 3), padding='same')
        self.enc_conv2 = nn.Conv2d(48, 48, (3, 3), padding='same')
        self.enc_conv3 = nn.Conv2d(48, 48, (3, 3), padding='same')
        self.enc_conv4 = nn.Conv2d(48, 48, (3, 3), padding='same')
        self.enc_conv5 = nn.Conv2d(48, 48, (3, 3), padding='same')
        self.enc_conv6 = nn.Conv2d(48, 48, (3, 3), padding='same')
        self.upsample5 = nn.UpsamplingNearest2d(scale_factor=2)
                  # concatenation 
        self.dec_conv5a = nn.Conv2d(96, 96, (3, 3), padding='same')
        self.dec_conv5b = nn.Conv2d(96, 96, (3, 3), padding='same')
        self.upsample4 = nn.UpsamplingNearest2d(scale_factor=2)
                  # concatenation
        self.dec_conv4a = nn.Conv2d(144, 96, (3, 3), padding='same')
        self.dec_conv4b = nn.Conv2d(96, 96, (3, 3), padding='same')
        self.upsample3 = nn.UpsamplingNearest2d(scale_factor=2)
                  # concatenation
        self.dec_conv3a = nn.Conv2d(144, 96, (3, 3), padding='same')
        self.dec_conv3b = nn.Conv2d(96, 96, (3, 3), padding='same')
        self.upsample2 = nn.UpsamplingNearest2d(scale_factor=2)
                  # concatenatio
        self.dec_conv2a = nn.Conv2d(144, 96, (3, 3), padding='same')
        self.dec_conv2b = nn.Conv2d(96, 96, (3, 3), padding='same')
        self.upsample1 = nn.UpsamplingNearest2d(scale_factor=2)
                  # concatenation
        self.custom = nn.Conv2d(144, 99, (3, 3), padding='same')
        self.dec_conv1a = nn.Conv2d(96 + 3, 64, (3, 3), padding='same')
        self.dec_conv1b = nn.Conv2d(64, 32, (3, 3), padding='same')
        self.dec_conv1c = nn.Conv2d(32, 3, (3, 3), padding='same')
                  #what does linear activation mean (output unchanged?)

    def forward(self, x):
      input = x.clone()
      x = F.leaky_relu(self.enc_conv0(x), negative_slope=0.1)
      x = F.leaky_relu(self.enc_conv1(x), negative_slope=0.1)
      #pool1 = F.max_pool2d(x, 2)
      #x = F.leaky_relu(self.enc_conv2(pool1), negative_slope=0.1)
      '''pool2 = F.max_pool2d(x, 2)
      x = F.leaky_relu(self.enc_conv3(pool2), negative_slope=0.1)'''
      pool3 = F.max_pool2d(x, 2) #MAKE SURE THEY HAVE CORRECT VALUES AND THEY ARE NOT CHANGED BY FOLLOWING OPERATIONS
      x = F.leaky_relu(self.enc_conv4(pool3), negative_slope=0.1)
      pool4 = F.max_pool2d(x, 2)
      x = F.leaky_relu(self.enc_conv5(pool4), negative_slope=0.1)
      ''' pool5 = F.max_pool2d(x, 2)
      x = F.leaky_relu(self.enc_conv6(pool5), negative_slope=0.1)
      x = self.upsample5(x)'''
      x = torch.cat((x, pool4), dim=1)#.view(-1, 96, 2, 2) #MAKE SURE THEY ARE STACKED ON THE CORRECT DIMENSIONS
      '''x = F.leaky_relu(self.dec_conv5a(x), negative_slope=0.1)
      x = F.leaky_relu(self.dec_conv5b(x), negative_slope=0.1)
      x = self.upsample4(x) 
      x = torch.cat((x, pool3), dim=1)
      
      x = F.leaky_relu(self.dec_conv4a(x), negative_slope=0.1)
      x = F.leaky_relu(self.dec_conv4b(x), negative_slope=0.1)
      x = self.upsample3(x) 
      x = torch.cat((x, pool2), dim=1)'''
      x = self.upsample3(x) # custom
      x = torch.cat((x, pool3), dim=1) #custom
      x = F.leaky_relu(self.dec_conv3a(x), negative_slope=0.1)
      x = F.leaky_relu(self.dec_conv3b(x), negative_slope=0.1)
      x = self.upsample2(x) 
      #x = torch.cat((x, pool1), dim=1)
      '''
      x = F.leaky_relu(self.dec_conv2a(x), negative_slope=0.1)
      x = F.leaky_relu(self.dec_conv2b(x), negative_slope=0.1)
      x = self.upsample1(x) '''
      x = torch.cat((x, input), dim=1)
      #x = F.leaky_relu(self.custom(x), negative_slope=0.1) #custom
      x = F.leaky_relu(self.dec_conv1a(x), negative_slope=0.1)
      x = self.upsample1(x) #custom
      x = F.leaky_relu(self.dec_conv1b(x), negative_slope=0.1)
      x = F.max_pool2d(x, 2) #custom
      x = self.dec_conv1c(x)
      return x

In [7]:
class LeanerModel(torch.nn.Module):
    def __init__(self, transposed_conv=False):
        super().__init__()
        self.enc_conv0 = nn.Conv2d(3, 48, (3, 3), padding='same')
        self.enc_conv1 = nn.Conv2d(48, 48, (3, 3), padding='same')
        self.enc_conv3 = nn.Conv2d(48, 48, (3, 3), padding='same')
                  # concatenation
        self.upsample2 = nn.UpsamplingNearest2d(scale_factor=2)
                  # concatenatio
        self.dec_conv2a = nn.Conv2d(96, 96, (3, 3), padding='same')
        self.dec_conv2b = nn.Conv2d(96, 96, (3, 3), padding='same')
        self.upsample1 = nn.UpsamplingNearest2d(scale_factor=2)
                  # concatenation
        self.custom = nn.Conv2d(144, 99, (3, 3), padding='same')
        self.dec_conv1a = nn.Conv2d(96 + 3, 64, (3, 3), padding='same')
        self.dec_conv1b = nn.Conv2d(64, 32, (3, 3), padding='same')
        self.dec_conv1c = nn.Conv2d(32, 3, (3, 3), padding='same')
                  #what does linear activation mean (output unchanged?)

    def forward(self, x):
      input = x.clone()
      x = F.leaky_relu(self.enc_conv0(x), negative_slope=0.1)
      x = F.leaky_relu(self.enc_conv1(x), negative_slope=0.1)
      pool1 = F.max_pool2d(x, 2) #48x16x16
      x = F.leaky_relu(self.enc_conv3(pool1), negative_slope=0.1)
      #pool3 = F.max_pool2d(x, 2) #48x8x8
      #x = self.upsample2(x) #48x32x32
      x = torch.cat((x, pool1), dim=1)#96x16x16
      x = F.leaky_relu(self.dec_conv2a(x), negative_slope=0.1)
      x = F.leaky_relu(self.dec_conv2b(x), negative_slope=0.1)
      x = self.upsample1(x)  #96x32x32
      x = torch.cat((x, input), dim=1) #99x32x32
      x = F.leaky_relu(self.dec_conv1a(x), negative_slope=0.1)
      x = F.leaky_relu(self.dec_conv1b(x), negative_slope=0.1)
      x = self.dec_conv1c(x)
      return x

In [12]:
print(torch.cuda.is_available())

True


In [43]:
model = LeanerModel()
train_model(model, noisy_imgs1.float(), noisy_imgs2.float(), nn.MSELoss(), torch.optim.Adam(model.parameters()), 500, 100)
torch.save(model.state_dict(), 'mymodule.pt')

0


KeyboardInterrupt: ignored

In [42]:
model = LeanerModel()
model.load_state_dict(torch.load('mymodule.pt'))
print(model.state_dict())
#print(validate(model, noisy_val.float(), clean_val.float()))

OrderedDict([('enc_conv0.weight', tensor([[[[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]]],


        [[[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]]],


        [[[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]]],


        ...,


        [[[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
     