In [1]:
import wandb
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import random
from torchvision import datasets
import torchvision.transforms as transforms
import os
from torch.utils.data import Dataset
# import natsort
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import OrderedDict 
import datetime
from scipy import special
import cv2

In [20]:
!pip3 install scipy

Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.
To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.
Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [None]:
# Mount Google Drive under /content/gdrive. Requires the `google.colab`
# package, so this cell is only meaningful when running on Colab.
from google.colab import drive
drive.mount('/content/gdrive' )
# drive.flush_and_unmount()

In [2]:
# Authenticate this session with Weights & Biases before any run is started.
wandb.login()

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33makshi11[0m (use `wandb login --relogin` to force relogin)


True

In [2]:
class NoisyImageDataset(Dataset):
    """Dataset of single-channel images read from one flat directory.

    Every entry of ``main_dir`` whose name does not start with ``'.'`` is
    treated as an image file. The names are stored sorted so that index ``i``
    refers to the same file on every run; ``os.listdir`` alone returns them
    in arbitrary order, which makes seeded splits irreproducible.

    Args:
        main_dir: Directory containing the image files.
        transform: Callable applied to the PIL image after it has been
            converted to mode ``"LA"``. Its result is indexed with ``[0]`` to
            pick the first channel, so it is expected to be channel-first.
    """

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        self.total_imgs = sorted(
            name for name in os.listdir(main_dir) if not name.startswith('.')
        )

    def __len__(self):
        """Return the number of image files found in ``main_dir``."""
        return len(self.total_imgs)

    def __getitem__(self, idx):
        """Load image ``idx`` and return its first channel as ``(1, H, W)``."""
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        image = Image.open(img_loc).convert("LA")
        tensor_image = self.transform(image)
        # Keep channel 0 (luminance of the "LA" pair) and restore the channel
        # axis. unsqueeze keeps the real height/width instead of forcing
        # 256x256, which raised (or silently scrambled pixels) for other sizes.
        return tensor_image[0].unsqueeze(0)

In [3]:
def load_data(images_dir_list, L, p=1, batch_size=10):
  """Build shuffled train/validation loaders over a list of image folders.

  Each entry of ``images_dir_list`` is a path prefix (for example something
  like ``'.../plants_L_'``); when ``L`` is non-zero, ``str(L)`` is appended to
  it. Every resulting folder is wrapped in three separate
  ``NoisyImageDataset`` objects sharing one random-crop transform, and all of
  them are concatenated before splitting.

  Args:
    images_dir_list: iterable of directory path prefixes.
    L: value appended to every prefix (nothing is appended when it equals 0).
    p: fraction of the concatenated data that goes to the training split.
    batch_size: batch size used by both loaders.

  Returns:
    ``(trainloader, validloader)``. ``validloader`` is an empty list when the
    validation split contains no samples.
  """
  crop_to_tensor = transforms.Compose([
      transforms.RandomResizedCrop(256),
      transforms.ToTensor(),
  ])
  suffix = str(L) if L != 0 else ''

  # Three dataset objects per folder, in folder order.
  copies_per_dir = 3
  parts = [
      NoisyImageDataset(prefix + suffix, transform=crop_to_tensor)
      for prefix in images_dir_list
      for _ in range(copies_per_dir)
  ]
  merged = torch.utils.data.ConcatDataset(parts)

  n_train = int(len(merged) * p)
  print(n_train)
  train_part, valid_part = torch.utils.data.random_split(
      merged, [n_train, len(merged) - n_train]
  )

  loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=2)
  trainloader = torch.utils.data.DataLoader(train_part, **loader_options)
  if len(valid_part) != 0:
    validloader = torch.utils.data.DataLoader(valid_part, **loader_options)
  else:
    validloader = []

  return trainloader, validloader

# Network Architecture and other necessary functions

In [3]:
class UNet(nn.Module):
  """Four-level U-Net with skip connections and a sigmoid output.

  The encoder doubles the channel count at every level (starting from
  ``init_features``) and halves the resolution with 2x2 max pooling; the
  decoder mirrors it with stride-2 transposed convolutions and concatenates
  the matching encoder output before each decoder block. In this variant the
  first encoder block uses 7x7 kernels and the second 5x5; every other block
  uses 3x3.

  NOTE: this notebook contains another cell that also defines a class named
  ``UNet`` (3x3 kernels throughout, ``init_features=64``); whichever cell was
  executed last is the definition in effect.

  Args:
    in_channels: channels of the input image.
    out_channels: channels of the output image.
    init_features: channel count of the first encoder level.
  """

  def __init__(self, in_channels=1, out_channels=1, init_features=32):
    super(UNet, self).__init__()

    # Channel widths of encoder levels 1-4 and of the bottleneck.
    c1, c2, c3, c4, c5 = (init_features * mult for mult in (1, 2, 4, 8, 16))

    def down():
      return nn.MaxPool2d(kernel_size=2, stride=2)

    def up(c_in, c_out):
      return nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)

    # Sub-modules are created in the same order as they are used so that
    # parameter ordering (and seeded initialisation) stays stable.
    self.encoder1 = UNet._block(in_channels, c1, name="enc1", kernel_size=7, padding=3)
    self.pool1 = down()
    self.encoder2 = UNet._block(c1, c2, name="enc2", kernel_size=5, padding=2)
    self.pool2 = down()
    self.encoder3 = UNet._block(c2, c3, name="enc3")
    self.pool3 = down()
    self.encoder4 = UNet._block(c3, c4, name="enc4")
    self.pool4 = down()

    self.bottleneck = UNet._block(c4, c5, name="bottleneck")

    # Each decoder block sees the upsampled map concatenated with the skip
    # connection, hence twice the level's width on its input.
    self.upconv4 = up(c5, c4)
    self.decoder4 = UNet._block(c4 * 2, c4, name="dec4")
    self.upconv3 = up(c4, c3)
    self.decoder3 = UNet._block(c3 * 2, c3, name="dec3")
    self.upconv2 = up(c3, c2)
    self.decoder2 = UNet._block(c2 * 2, c2, name="dec2")
    self.upconv1 = up(c2, c1)
    self.decoder1 = UNet._block(c1 * 2, c1, name="dec1")

    self.conv = nn.Conv2d(in_channels=c1, out_channels=out_channels, kernel_size=1)

  def forward(self, x):
    """Run the network; the output has the input's spatial size, values in (0, 1)."""
    skip1 = self.encoder1(x)
    skip2 = self.encoder2(self.pool1(skip1))
    skip3 = self.encoder3(self.pool2(skip2))
    skip4 = self.encoder4(self.pool3(skip3))

    features = self.bottleneck(self.pool4(skip4))

    decoder_path = (
        (self.upconv4, self.decoder4, skip4),
        (self.upconv3, self.decoder3, skip3),
        (self.upconv2, self.decoder2, skip2),
        (self.upconv1, self.decoder1, skip1),
    )
    for upsample, decode, skip in decoder_path:
      features = decode(torch.cat((upsample(features), skip), dim=1))

    return torch.sigmoid(self.conv(features))

  @staticmethod
  def _block(in_channels, features, name, kernel_size=3, padding=1):
    """Return two (conv -> batch-norm -> ReLU) stages as a named ``nn.Sequential``.

    Layer names are ``name + "conv1"``, ``name + "norm1"``, ``name + "relu1"``
    and the same with suffix ``2``. The convolutions carry no bias.
    """
    layers = OrderedDict()
    for stage, stage_in in enumerate((in_channels, features), start=1):
      suffix = str(stage)
      layers[name + "conv" + suffix] = nn.Conv2d(
          in_channels=stage_in,
          out_channels=features,
          kernel_size=kernel_size,
          padding=padding,
          bias=False,
      )
      layers[name + "norm" + suffix] = nn.BatchNorm2d(num_features=features)
      layers[name + "relu" + suffix] = nn.ReLU(inplace=True)
    return nn.Sequential(layers)


In [3]:
class UNet(nn.Module):
  """Four-level U-Net with skip connections and a sigmoid output.

  The encoder doubles the channel count at every level (starting from
  ``init_features``) and halves the resolution with 2x2 max pooling; the
  decoder mirrors it with stride-2 transposed convolutions and concatenates
  the matching encoder output before each decoder block. All blocks use the
  default 3x3 kernels.

  NOTE: this notebook contains another cell that also defines a class named
  ``UNet`` (7x7 / 5x5 kernels in the first two encoder blocks,
  ``init_features=32``); whichever cell was executed last is the definition
  in effect.

  Args:
    in_channels: channels of the input image.
    out_channels: channels of the output image.
    init_features: channel count of the first encoder level.
  """

  def __init__(self, in_channels=1, out_channels=1, init_features=64):
    super(UNet, self).__init__()

    # Channel widths of encoder levels 1-4 and of the bottleneck.
    c1, c2, c3, c4, c5 = (init_features * mult for mult in (1, 2, 4, 8, 16))

    def down():
      return nn.MaxPool2d(kernel_size=2, stride=2)

    def up(c_in, c_out):
      return nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)

    # Sub-modules are created in the same order as they are used so that
    # parameter ordering (and seeded initialisation) stays stable.
    self.encoder1 = UNet._block(in_channels, c1, name="enc1")
    self.pool1 = down()
    self.encoder2 = UNet._block(c1, c2, name="enc2")
    self.pool2 = down()
    self.encoder3 = UNet._block(c2, c3, name="enc3")
    self.pool3 = down()
    self.encoder4 = UNet._block(c3, c4, name="enc4")
    self.pool4 = down()

    self.bottleneck = UNet._block(c4, c5, name="bottleneck")

    # Each decoder block sees the upsampled map concatenated with the skip
    # connection, hence twice the level's width on its input.
    self.upconv4 = up(c5, c4)
    self.decoder4 = UNet._block(c4 * 2, c4, name="dec4")
    self.upconv3 = up(c4, c3)
    self.decoder3 = UNet._block(c3 * 2, c3, name="dec3")
    self.upconv2 = up(c3, c2)
    self.decoder2 = UNet._block(c2 * 2, c2, name="dec2")
    self.upconv1 = up(c2, c1)
    self.decoder1 = UNet._block(c1 * 2, c1, name="dec1")

    self.conv = nn.Conv2d(in_channels=c1, out_channels=out_channels, kernel_size=1)

  def forward(self, x):
    """Run the network; the output has the input's spatial size, values in (0, 1)."""
    skip1 = self.encoder1(x)
    skip2 = self.encoder2(self.pool1(skip1))
    skip3 = self.encoder3(self.pool2(skip2))
    skip4 = self.encoder4(self.pool3(skip3))

    features = self.bottleneck(self.pool4(skip4))

    decoder_path = (
        (self.upconv4, self.decoder4, skip4),
        (self.upconv3, self.decoder3, skip3),
        (self.upconv2, self.decoder2, skip2),
        (self.upconv1, self.decoder1, skip1),
    )
    for upsample, decode, skip in decoder_path:
      features = decode(torch.cat((upsample(features), skip), dim=1))

    return torch.sigmoid(self.conv(features))

  @staticmethod
  def _block(in_channels, features, name, kernel_size=3, padding=1):
    """Return two (conv -> batch-norm -> ReLU) stages as a named ``nn.Sequential``.

    Layer names are ``name + "conv1"``, ``name + "norm1"``, ``name + "relu1"``
    and the same with suffix ``2``. The convolutions carry no bias.
    """
    layers = OrderedDict()
    for stage, stage_in in enumerate((in_channels, features), start=1):
      suffix = str(stage)
      layers[name + "conv" + suffix] = nn.Conv2d(
          in_channels=stage_in,
          out_channels=features,
          kernel_size=kernel_size,
          padding=padding,
          bias=False,
      )
      layers[name + "norm" + suffix] = nn.BatchNorm2d(num_features=features)
      layers[name + "relu" + suffix] = nn.ReLU(inplace=True)
    return nn.Sequential(layers)


In [4]:
# Global flag passed to the training functions to decide whether tensors
# and the model are moved to the GPU.
use_cuda = torch.cuda.is_available()

In [6]:
# Sanity check: show whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())

True


In [5]:
# Start a Weights & Biases run in the "ssl" project; the wandb.log calls in
# the training functions report to this run.
wandb.init(project="ssl")
# wandb.login(force=True)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33makshi11[0m (use `wandb login --relogin` to force relogin)
  warn("The `IPython.html` package has been deprecated since IPython 4.0. "
[34m[1mwandb[0m: wandb version 0.10.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


# Optimizers, loss functions and training loop

# Train model for SAR Images

In [6]:
def simulate_speckle(clean_im, L):
    """Add simulated ``L``-look speckle to an image and clip it to [0, 1].

    For each look, two independent standard-normal fields are squared and
    summed. The looks are averaged, square-rooted to an amplitude,
    log-transformed, divided by ``log(256)`` and added to ``clean_im``.
    The noise is drawn on the same device as ``clean_im`` so the function
    also works for CUDA tensors.

    Args:
        clean_im: image tensor; presumably already log-scaled and normalised
            to [0, 1] by the caller (TODO confirm).
        L: number of looks, a positive integer.

    Returns:
        Tensor of the same shape as ``clean_im`` with values clamped to [0, 1].

    Raises:
        ValueError: if ``L`` is smaller than 1 (the average over looks would
            otherwise divide by zero and silently produce NaNs).
    """
    if L < 1:
        raise ValueError("L must be a positive number of looks, got {}".format(L))

    log_range = float(np.log(256))  # normalisation span of the log-domain data
    if clean_im.is_floating_point():
        noise_dtype = clean_im.dtype
    else:
        noise_dtype = torch.get_default_dtype()

    def _std_normal():
        return torch.randn(clean_im.shape, dtype=noise_dtype, device=clean_im.device)

    intensity = torch.zeros(clean_im.shape, dtype=noise_dtype, device=clean_im.device)
    for _ in range(L):
        # NOTE(review): each look has mean 2 (sum of two unit-variance
        # squares); unit-mean speckle would halve it. Confirm the scaling.
        intensity = intensity + _std_normal() ** 2 + _std_normal() ** 2

    s_amplitude = torch.sqrt(intensity / L)
    log_norm_speckle = torch.log(s_amplitude) / log_range
    noisy_im = clean_im + log_norm_speckle
    return torch.clamp(noisy_im, min=0.0, max=1.0)

In [6]:
# Un-normalised view of the SAR folder, used for computing dataset statistics.
# The loader's batch size equals the dataset length, so one batch holds
# every image.
dir_list = './qgissenti'
data_transforms = transforms.Compose([
                                      transforms.ToTensor()])
data_set = NoisyImageDataset(dir_list, data_transforms)
dataloader = torch.utils.data.DataLoader(data_set, batch_size=len(data_set), shuffle=True, num_workers=1)
# Disabled two-pass mean/std computation over all pixels, kept for reference:
# num_of_pixels = len(data_set) * 256 * 256
# total_sum = 0
# for batch in dataloader:
# #     print(batch[0].size())
#     total_sum += batch[0].sum()
# mean = total_sum / num_of_pixels
# sum_of_squared_error = 0
# for batch in dataloader:
#     sum_of_squared_error += ((batch[0] - mean).pow(2)).sum()
# std = torch.sqrt(sum_of_squared_error / num_of_pixels)

In [7]:
# data = next(iter(dataloader))
# mean, std = data[0].mean(), data[0].std()
# print(mean, std)

In [8]:
# Training view of the SAR folder: random rotation of up to 90 degrees, then
# Normalize(0.5, 0.5), which maps ToTensor's [0, 1] output to [-1, 1].
# Rebinds `data_transforms` from the statistics cell.
data_transforms = transforms.Compose([
                                      transforms.RandomRotation(90),
                                      transforms.ToTensor(),
                                      transforms.Normalize(0.5, 0.5)
                                        ])
dataset = NoisyImageDataset(dir_list, data_transforms)
# dataloaderSAR = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=True, num_workers=2)

In [9]:
# 90/10 random train/validation split of `dataset`. No seed is set in this
# cell, so the split depends on the current global RNG state.
l = int(len(dataset)*0.9)
train, valid = torch.utils.data.random_split(dataset, [l, len(dataset)-l])
trainloaderSAR = torch.utils.data.DataLoader(train, batch_size=16, shuffle=True, num_workers=2)
validloaderSAR = torch.utils.data.DataLoader(valid, batch_size=10, shuffle=True, num_workers=2)

In [13]:
# Generate a speckled copy of the BSD images on disk. Each image is randomly
# cropped/resized to 256, passed through lintolog, corrupted with 4-look
# simulated speckle, passed through logtolin and saved as a grayscale JPEG
# named after its loop index.
# NOTE(review): lintolog / logtolin are defined in cells further down this
# notebook, so on a fresh kernel this cell fails unless those run first.
dirpath = './BSD'
datatransforms = transforms.Compose([ 
                                      transforms.RandomResizedCrop(256),
                                      transforms.ToTensor()])
data_set_train = NoisyImageDataset(dirpath+'/train', datatransforms)
data_set_valid = NoisyImageDataset(dirpath+'/valid', datatransforms)
dataloadertrain = torch.utils.data.DataLoader(data_set_train, batch_size=1, shuffle=True, num_workers=1)
dataloadervalid = torch.utils.data.DataLoader(data_set_valid, batch_size=1, shuffle=True, num_workers=1)
# batch_size is 1, so image[0][0] below is the single 2-D image plane.
for idx,image in enumerate(dataloadertrain):
#     b_hat = 10.0
#     noise1 = torch.from_numpy(np.random.gamma(b_hat, 1/b_hat, image.shape))
#     image = image*noise1
    image = lintolog(image)
    image = simulate_speckle(image, 4)
    plt.imsave('./BSDnoisy/train/'+str(idx)+'.jpg',logtolin(image[0][0]), cmap='gray')
# NOTE(review): validation images are also written into ./BSDnoisy/train/
# (with a 'val' filename prefix); confirm that mixing both sets in one
# folder is intended.
for idx,image in enumerate(dataloadervalid):
#     b_hat = 10.0
#     noise1 = torch.from_numpy(np.random.gamma(b_hat, 1/b_hat, image.shape))
#     image = image*noise1
    image = lintolog(image)
    image = simulate_speckle(image, 4)
    plt.imsave('./BSDnoisy/train/val'+str(idx)+'.jpg',logtolin(image[0][0]), cmap='gray')
#     image = lintolog(image)
#     img = simulate_speckle(image, 10)
#     plt.imsave('./BSDcorr/n2/valid/'+str(idx)+'.jpg',logtolin(img[0][0]), cmap='gray')



In [7]:

# Loaders over the generated BSDnoisy/train folder with an 80/20 random split.
# NOTE: this rebinds `data_transforms`, `dataset`, `l`, `train` and `valid`,
# which the SAR cells above also assign; later cells see whichever ran last.
data_transforms = transforms.Compose([
                                      transforms.RandomRotation(90),
                                      transforms.ToTensor(),
                                      transforms.Normalize(0.5, 0.5)
                                        ])
# validset = NoisyImageDataset(dirpath+'/valid', data_transforms)
dataset = NoisyImageDataset('BSDnoisy/train', data_transforms)
l = int(len(dataset)*0.8)
train, valid = torch.utils.data.random_split(dataset, [l, len(dataset)-l])
trainloader = torch.utils.data.DataLoader(train, batch_size=16, shuffle=True, num_workers=2)
validloader = torch.utils.data.DataLoader(valid, batch_size=16, shuffle=True, num_workers=2)

In [None]:
# Print the mean and standard deviation of one element of a batch.
# NOTE(review): `dataloaderSAR` is only assigned in a commented-out line of
# an earlier cell, so this raises NameError unless it is defined elsewhere.
# NOTE(review): NoisyImageDataset yields bare tensors, so `data[0]` appears
# to select the first image of the batch rather than the whole batch.
data = next(iter(dataloaderSAR))
mean_new, std_new = data[0].mean(), data[0].std()
print(mean_new, std_new)

In [None]:
# Instantiate the network in float64 (the training functions feed it double
# tensors) and move it to the GPU when one is available.
model_S = UNet().double()
if use_cuda:
  model_S = model_S.cuda()
print(use_cuda)

In [None]:
def init_weights(m):
  """Initialise one module in place; meant to be used with ``model.apply``.

  * ``nn.Conv2d``: Kaiming-normal weights (its bias is left untouched).
  * ``nn.BatchNorm2d``: weight set to 1 and bias to 0.
  * ``nn.Linear``: Kaiming-normal weights, bias set to 0.

  Any other module type is left unchanged. Parameters that are ``None``
  (``bias=False`` layers, ``affine=False`` batch norm) are skipped instead
  of being passed to ``nn.init``, which would raise.

  Args:
    m: the module to initialise.
  """
  if isinstance(m, nn.Conv2d):
    nn.init.kaiming_normal_(m.weight)

  elif isinstance(m, nn.BatchNorm2d):
    if m.weight is not None:
      nn.init.constant_(m.weight, 1)
    if m.bias is not None:
      nn.init.constant_(m.bias, 0)

  elif isinstance(m, nn.Linear):
    nn.init.kaiming_normal_(m.weight)
    if m.bias is not None:
      nn.init.constant_(m.bias, 0)

In [None]:
# Apply init_weights to every sub-module. Module.apply returns the module
# itself, so model_SAR and model_S refer to the same object.
model_SAR = model_S.apply(init_weights)

In [None]:
# Register the model with the active Weights & Biases run.
wandb.watch(model_SAR)

[<wandb.wandb_torch.TorchGraph at 0x7f60ba0a2390>]

In [14]:
class SURE_Loss(nn.Module):
  """SURE-style loss with a finite-difference (Monte-Carlo) divergence term.

  ``forward`` returns the absolute value of

      (1 / batch_size) * (sum((y - y_hat)**2)
                          - batch_size * x_dim**2 * var
                          + 2 * var * div)

  where ``div = (1 / eps) * sum(n * (z_hat - y_hat))``.
  """

  def __init__(self):
    super().__init__()

  def forward(self, y, y_hat, z_hat, n, x_dim, var, batch_size, eps):
    """Compute the loss.

    Args:
      y: reference tensor compared against ``y_hat``.
      y_hat: network output for the unperturbed input.
      z_hat: network output for the input perturbed by ``eps * n``.
      n: the random probe tensor used for the perturbation.
      x_dim: side length used for the per-image pixel count (``x_dim**2``).
      var: noise variance.
      batch_size: number of images used for normalisation.
      eps: perturbation step size.

    Returns:
      Scalar tensor holding the absolute value of the estimate.
    """
    probe_response = torch.sum((n * (z_hat - y_hat)))
    div = (1/eps)*probe_response
    residual_energy = torch.sum((y - y_hat)**2)
    noise_energy = (batch_size * x_dim * x_dim * var)
    sure = (1.0 / batch_size)*(residual_energy - noise_energy + (2*var*div))
    return torch.abs(sure)

In [15]:
def stretch(image, max_val=1, min_val=0):
    """Linearly rescale ``image`` so its minimum maps to ``min_val`` and its maximum to ``max_val``.

    The computation is done in float32 and the result is cast back to the
    input's original dtype.

    Args:
        image: NumPy array to rescale.
        max_val: value the array maximum is mapped to.
        min_val: value the array minimum is mapped to.

    Returns:
        Rescaled array with the same dtype as ``image``. A constant image
        (maximum equal to minimum) is mapped to ``min_val`` everywhere
        instead of dividing zero by zero and returning NaNs.
    """
    image, dtype = to_float(image)
    img_min = np.min(image)
    img_max = np.max(image)
    value_range = img_max - img_min
    if value_range == 0:
        # Every pixel equals the minimum, so every pixel maps to min_val.
        image_stretched = np.full_like(image, min_val)
    else:
        image_stretched = min_val + (max_val - min_val) * (image - img_min) / value_range
    image_stretched = change_type(image_stretched, dtype)
    return image_stretched
    
def change_type(image, dtype):
    """Return ``image`` cast to ``dtype``.

    When the array already has that dtype it is returned as-is (same object,
    no copy).
    """
    if image.dtype == dtype:
        return image
    return image.astype(dtype)

def to_float(image):
    """Return ``(image_as_float32, original_dtype)``.

    A float32 input is returned unchanged (same object); any other dtype is
    converted with ``astype``. The original dtype is returned so the caller
    can cast back later.
    """
    original_dtype = image.dtype
    if original_dtype == np.float32:
        return image, original_dtype
    return image.astype(np.float32), original_dtype

In [16]:
def lintolog(image, max_val=1.0, min_val=-1.0):
    """Map a linear-scale tensor to a normalised log scale.

    The image is rescaled so that the interval ``[min_val, max_val]``
    corresponds to ``[0, LIN_MAX]``, then ``log10(x + 1)`` is applied and the
    result is divided by ``log10(LIN_MAX + 1)``. Values inside
    ``[min_val, max_val]`` therefore end up in ``[0, 1]``.

    Args:
        image: torch tensor on any device; it is detached and copied to the
            CPU before conversion to NumPy, so GPU tensors and tensors that
            require grad are accepted.
        max_val: upper end of the input value range.
        min_val: lower end of the input value range.

    Returns:
        A float64 CPU tensor with the same shape as ``image``.
    """
    # Tensor.numpy() only works for CPU tensors that do not require grad.
    image = image.detach().cpu().numpy()
    LIN_MAX = 1.0
    LOG_MAX = np.log10(LIN_MAX + 1)
    image, dtype = to_float(image)
    img_min = np.min(image)
    img_max = np.max(image)
    # Positions the image extremes take once [min_val, max_val] is mapped
    # onto [0, LIN_MAX]; stretch() then places the extremes exactly there.
    stretch_max = LIN_MAX * (img_max - min_val) / (max_val - min_val)
    stretch_min = LIN_MAX * (img_min - min_val) / (max_val - min_val)
    image = stretch(image, max_val=stretch_max, min_val=stretch_min)
    image = np.log10(image + 1)
    image = image / LOG_MAX
    image = change_type(image, dtype)
    # as_tensor converts any NumPy dtype, unlike the legacy typed constructor.
    return torch.as_tensor(image, dtype=torch.float64)

In [17]:
def logtolin(image, max_val=1.0, min_val=0):
    """Map a normalised log-scale tensor back to linear scale.

    The image is rescaled so that the interval ``[min_val, max_val]``
    corresponds to ``[0, log10(LIN_MAX + 1)]``, then ``10**x - 1`` is applied
    and the result is divided by ``LIN_MAX``.

    Args:
        image: torch tensor on any device; it is detached and copied to the
            CPU before conversion to NumPy, so GPU tensors (for example
            model outputs) are accepted.
        max_val: upper end of the input value range.
        min_val: lower end of the input value range.

    Returns:
        A float64 CPU tensor with the same shape as ``image``.
    """
    # Tensor.numpy() only works for CPU tensors that do not require grad.
    image = image.detach().cpu().numpy()
    LIN_MAX = 1.0
    LOG_MAX = np.log10(LIN_MAX + 1)
    image, dtype = to_float(image)
    img_min = np.min(image)
    img_max = np.max(image)
    # Positions the image extremes take once [min_val, max_val] is mapped
    # onto [0, LOG_MAX]; stretch() then places the extremes exactly there.
    stretch_max = LOG_MAX * (img_max - min_val) / (max_val - min_val)
    stretch_min = LOG_MAX * (img_min - min_val) / (max_val - min_val)
    image = stretch(image, max_val=stretch_max, min_val=stretch_min)
    image = np.power(10, image) - 1
    image = image / LIN_MAX
    image = change_type(image, dtype)
    # as_tensor converts any NumPy dtype, unlike the legacy typed constructor.
    return torch.as_tensor(image, dtype=torch.float64)

In [16]:
# Mean-squared-error criterion. NOTE: a later cell rebinds `criterion` to
# SURE_Loss(); whichever cell ran last determines what training uses.
criterion = nn.MSELoss()

In [18]:
# Adam on all model parameters with cosine annealing of the learning rate
# from 0.001 down to eta_min=0 over T_max=30 scheduler steps.
# NOTE: train_SAR_N2N_valid and train_SAR_ESURE_valid read `scheduler` as a
# notebook global (it is not one of their parameters), so this cell must
# have been run before they are called.
optimizer = optim.Adam(model_SAR.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 30, eta_min=0, last_epoch=-1)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [19]:
# Use the SURE-style loss. This rebinds `criterion`, replacing the
# nn.MSELoss() instance assigned in an earlier cell.
criterion = SURE_Loss()

In [16]:
# Print polygamma(1, 15), i.e. the first derivative of the digamma function
# (the trigamma function) evaluated at 15.
print(special.polygamma(1,15))

0.0689382278476838


In [19]:
def train_SAR(n_epochs, trainloader, model, optimizer, criterion, use_cuda, save_path, lr=0.01):
  """Train ``model`` on log-transformed batches with a SURE-style criterion.

  Every batch is converted to double, passed through ``lintolog`` and fed to
  the model twice: once as is and once perturbed by ``eps * n`` with ``n``
  standard-normal. Both outputs go to ``criterion`` together with a variance
  derived from the batch itself (``(1 / var(batch) / 255) ** 2``). The running
  mean of the batch losses is logged to wandb, and the model's state dict is
  written to ``save_path`` whenever the epoch mean does not exceed the best
  one seen so far.

  Args:
    n_epochs: number of passes over ``trainloader``.
    trainloader: iterable yielding image batches as tensors.
    model: network to train (expected to accept float64 input).
    optimizer: optimiser already bound to ``model``'s parameters.
    criterion: callable with the signature of ``SURE_Loss.forward``.
    use_cuda: when true, batches are moved to the GPU; otherwise everything
      stays on the CPU.
    save_path: checkpoint file path passed to ``torch.save``.
    lr: unused; kept for backward compatibility.

  Returns:
    The trained ``model`` (same object that was passed in).
  """
  train_loss_min = float("inf")  # np.Inf no longer exists in NumPy 2.0
  eps = 0.000001   # perturbation step of the divergence estimate
  x_dim = 256      # image side length assumed by the loss's pixel-count term

  for epoch in range(1, n_epochs+1):
    # monitor time
    print(datetime.datetime.now())
    train_loss = 0.0

    model.train()
    current_lr = optimizer.param_groups[0]['lr']
    print(current_lr)
    wandb.log({"lr": current_lr})

    for batch_idx, image in enumerate(trainloader):
      print(batch_idx, end=" ")
      image = lintolog(image.double())
      if use_cuda:
        image = image.cuda()

      optimizer.zero_grad()
      y_hat = model(image)

      # The probe is created on whatever device the batch lives on, so CPU
      # runs no longer hit an unconditional .cuda() call.
      n = torch.from_numpy(np.random.normal(0, 1, size=tuple(image.size()))).to(image.device)
      z_hat = model(image + n * eps)

      sigma = 1/torch.var(image)
      var = (sigma/255.0)**2
      # Use the real batch size; a hard-coded value gives a wrong noise term
      # whenever the loader's batch size differs or the last batch is short.
      loss = criterion(image, y_hat, z_hat, n, x_dim, var, image.size(0), eps)
      loss.backward()
      optimizer.step()

      batch_loss = loss.item()
      train_loss = train_loss + ((1 / (batch_idx + 1)) * (batch_loss - train_loss))
      wandb.log({"Running Loss": batch_loss})

    # No learning-rate scheduler is stepped and no validation pass is run
    # in this variant.
    print()

    print('Epoch: {} \tTraining Loss: {:.6f} '.format(
        epoch,
        train_loss,
        ))
    wandb.log({"Train Loss": train_loss})
    if train_loss <= train_loss_min:
        print('Train loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(train_loss_min, train_loss))
        torch.save(model.state_dict(), save_path)
        train_loss_min = train_loss
  return model

In [14]:
def _speckled_log_pair(image, b_hat, use_cuda):
  """Return ``(clean_log, noisy_log)`` for one batch.

  The batch is converted to double and multiplied element-wise by
  Gamma(shape=``b_hat``, scale=``1 / b_hat``) noise; both the clean and the
  noisy batch are then passed through ``lintolog`` and moved to the GPU when
  ``use_cuda`` is true.
  """
  image = image.double()
  noise = torch.from_numpy(np.random.gamma(b_hat, 1/b_hat, image.shape))
  noisy_log = lintolog(image*noise)
  clean_log = lintolog(image)
  if use_cuda:
    clean_log = clean_log.cuda()
    noisy_log = noisy_log.cuda()
  return clean_log, noisy_log


def train_SAR_N2N_valid(n_epochs, trainloader, validloader, model, optimizer, criterion, use_cuda, save_path, lr=0.0001):
  """Train ``model`` to map synthetically speckled batches to their clean version.

  Each batch is corrupted with multiplicative Gamma noise (shape 10), both
  versions are log-transformed with ``lintolog``, and the loss is
  ``criterion(clean, model(noisy))``. After every epoch the learning-rate
  scheduler is stepped, a validation pass is run without gradient tracking,
  and the state dict is written to ``save_path`` whenever the mean
  validation loss does not exceed the best one seen so far.

  NOTE: ``scheduler`` is read from the notebook's global scope; it is not a
  parameter of this function.

  Args:
    n_epochs: number of passes over ``trainloader``.
    trainloader: iterable yielding training image batches as tensors.
    validloader: iterable yielding validation image batches as tensors.
    model: network to train (expected to accept float64 input).
    optimizer: optimiser already bound to ``model``'s parameters.
    criterion: callable taking ``(clean, prediction)``.
    use_cuda: when true, batches are moved to the GPU.
    save_path: checkpoint file path passed to ``torch.save``.
    lr: unused; kept for backward compatibility.

  Returns:
    The trained ``model`` (same object that was passed in).
  """
  valid_loss_min = float("inf")  # np.Inf no longer exists in NumPy 2.0
  b_hat = 10.0  # shape parameter of the Gamma speckle

  for epoch in range(1, n_epochs+1):
    # monitor time
    print(datetime.datetime.now())
    train_loss = 0.0
    valid_loss = 0.0

    ###################
    # train the model #
    ###################
    model.train()
    current_lr = optimizer.param_groups[0]['lr']
    print(current_lr)
    wandb.log({"lr": current_lr})
    for batch_idx, image in enumerate(trainloader):
      print(batch_idx, end=" ")
      image, y1 = _speckled_log_pair(image, b_hat, use_cuda)

      optimizer.zero_grad()
      y_hat = model(y1)
      loss = criterion(image, y_hat)
      loss.backward()
      optimizer.step()
      train_loss = train_loss + ((1 / (batch_idx + 1)) * (loss.item() - train_loss))
      wandb.log({"Loss": train_loss})

    scheduler.step()

    ######################
    # validate the model #
    ######################
    model.eval()
    # No gradients are needed here; skipping the autograd graph saves memory.
    with torch.no_grad():
      for batch_idx, image in enumerate(validloader):
        image, y1 = _speckled_log_pair(image, b_hat, use_cuda)
        y_hat = model(y1)
        loss = criterion(image, y_hat)
        valid_loss = valid_loss + ((1 / (batch_idx + 1)) * (loss.item() - valid_loss))

    print()

    print('Epoch: {} \tTraining Loss: {:.6f} \tValid Loss: {:.6f}'.format(
        epoch,
        train_loss,
        valid_loss
        ))
    wandb.log({"Train Loss": train_loss,
              "Valid Loss": valid_loss})
    if valid_loss <= valid_loss_min:
        print('Valid loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min, valid_loss))
        torch.save(model.state_dict(), save_path)
        valid_loss_min = valid_loss
  return model

In [20]:
def train_SAR_ESURE_valid(n_epochs, trainloader, validloader, model, optimizer, criterion, use_cuda, save_path, lr=0.0001):
  """Train ``model`` with a SURE-style criterion on Gaussian-corrupted log images.

  Each batch is converted to double and log-transformed with ``lintolog``;
  zero-mean Gaussian noise with standard deviation 0.1 is added to form the
  network input ``y1``. The model is evaluated on ``y1`` and on ``y1``
  perturbed by ``eps * n`` (``n`` standard-normal), and both outputs are
  passed to ``criterion``. After every epoch the learning-rate scheduler is
  stepped and the state dict is written to ``save_path`` whenever the mean
  training loss does not exceed the best one seen so far.

  NOTE: ``scheduler`` is read from the notebook's global scope; it is not a
  parameter of this function.
  NOTE(review): the injected noise has std 0.1 while ``var`` is derived from
  ``sigma = 15`` on a 0-255 scale (std of about 0.059), and ``criterion``
  receives the clean ``image`` rather than the noisy ``y1`` as its first
  argument. Confirm both are intended.

  Args:
    n_epochs: number of passes over ``trainloader``.
    trainloader: iterable yielding image batches as tensors.
    validloader: unused (no validation pass is run); kept for backward
      compatibility.
    model: network to train (expected to accept float64 input).
    optimizer: optimiser already bound to ``model``'s parameters.
    criterion: callable with the signature of ``SURE_Loss.forward``.
    use_cuda: when true, batches are moved to the GPU; otherwise everything
      stays on the CPU.
    save_path: checkpoint file path passed to ``torch.save``.
    lr: unused; kept for backward compatibility.

  Returns:
    The trained ``model`` (same object that was passed in).
  """
  train_loss_min = float("inf")  # np.Inf no longer exists in NumPy 2.0
  sigma = 15.0                   # assumed noise level on a 0-255 scale
  eps = 1.6*sigma*0.0001         # perturbation step of the divergence estimate
  var = (sigma/255.0)**2
  x_dim = 256                    # image side length assumed by the loss

  for epoch in range(1, n_epochs+1):
    # monitor time
    print(datetime.datetime.now())
    train_loss = 0.0

    ###################
    # train the model #
    ###################
    model.train()
    current_lr = optimizer.param_groups[0]['lr']
    print(current_lr)
    wandb.log({"lr": current_lr})
    for batch_idx, image in enumerate(trainloader):
      print(batch_idx, end=" ")
      image = lintolog(image.double())
      y1 = image + torch.from_numpy(np.random.normal(0, 0.1, image.shape))
      if use_cuda:
        image = image.cuda()
        y1 = y1.cuda()

      optimizer.zero_grad()
      y_hat = model(y1)

      # The probe is created on whatever device the batch lives on, so CPU
      # runs no longer hit an unconditional .cuda() call.
      n = torch.from_numpy(np.random.normal(0, 1, size=tuple(y1.size()))).to(y1.device)
      z_hat = model(y1 + n * eps)

      # Use the real batch size so a short final batch is normalised correctly.
      loss = criterion(image, y_hat, z_hat, n, x_dim, var, y1.size(0), eps)
      loss.backward()
      optimizer.step()
      train_loss = train_loss + ((1 / (batch_idx + 1)) * (loss.item() - train_loss))
      wandb.log({"Loss": train_loss})

    scheduler.step()

    print()

    print('Epoch: {} \tTraining Loss: {:.6f}'.format(
        epoch,
        train_loss,
        ))
    wandb.log({"Train Loss": train_loss,
              })
    if train_loss <= train_loss_min:
        print('Train loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(train_loss_min, train_loss))
        torch.save(model.state_dict(), save_path)
        train_loss_min = train_loss
  return model

In [None]:
# Launch 30 epochs of SURE-style training on the SAR loaders; the checkpoint
# is written to ./checkpoints/checkpoint_SAR.pt. Relies on the globals
# `optimizer`, `criterion` and `scheduler` created in earlier cells.
model_trained = train_SAR_ESURE_valid(30, trainloaderSAR, validloaderSAR, model_SAR, optimizer, criterion, use_cuda, './checkpoints/checkpoint_SAR.pt')