In [None]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
import matplotlib.pyplot as plt
from IPython.display import clear_output
from torchvision import transforms
import torchvision
from torch.utils.data import Dataset, DataLoader 
from tqdm.notebook import tqdm

In [None]:
# Extract the training archive into the current working directory.
# The archive path looks like a mounted Google Drive location; adjust it for your runtime.
!unzip /content/drive/MyDrive/data.zip # dataset used to train, pls change path as per instance
# Wipe unzip's per-file listing from the cell output.
clear_output()

In [None]:
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
# Normalisation statistics laid out as (means, stds); denorm() reads them to map
# normalised tensors back to displayable pixel values.
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
%matplotlib inline
def denorm(img_tensors, mean=None, std=None):
    """Undo channel-wise normalisation, i.e. compute ``x * std + mean``.

    Args:
        img_tensors: Tensor whose third-from-last axis is the channel axis,
            e.g. ``(C, H, W)`` or ``(N, C, H, W)``.
        mean: Per-channel means (sequence or scalar). Defaults to ``stats[0]``.
        std: Per-channel standard deviations (sequence or scalar). Defaults to
            ``stats[1]``.

    Returns:
        A new tensor with every channel rescaled by its own statistics.
    """
    if mean is None:
        mean = stats[0]
    if std is None:
        std = stats[1]
    # Integer inputs are promoted to the default float dtype, as plain
    # multiplication by a Python float would do.
    dtype = img_tensors.dtype if img_tensors.is_floating_point() else torch.get_default_dtype()
    # Shape (C, 1, 1) broadcasts over H and W and over an optional batch axis.
    mean_t = torch.as_tensor(mean, dtype=dtype, device=img_tensors.device).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=img_tensors.device).reshape(-1, 1, 1)
    return img_tensors * std_t + mean_t
def show_images(image):
    """Render a normalised image tensor (or a batch of them) in one column.

    Args:
        image: Tensor accepted by ``make_grid`` -- presumably ``(C, H, W)`` or
            ``(N, C, H, W)`` normalised with ``stats``; may live on any device
            and may still carry autograd history.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    # detach(): drop autograd history. cpu(): matplotlib needs host memory.
    # clamp(): imshow expects float RGB in [0, 1]; a raw network output can
    # overshoot and would otherwise be clipped with a warning.
    pixels = denorm(image.detach().cpu()).clamp(0, 1)
    # nrow=1 puts one image per grid row; permute moves channels last for imshow.
    ax.imshow(make_grid(pixels, nrow=1).permute(1, 2, 0))

In [None]:
import os
# Roots of the two training folders whose file names are intersected below.
path1 = "/content/wm-nowm/train/no-watermark"
path2 = "/content/wm-nowm/train/watermark"
def findCommonDeep(path1, path2):
    """Return the file names that occur under both directory trees.

    Both trees are walked recursively and only base names are compared, so a
    file matches whenever the same name exists anywhere under each root.
    Callers that rebuild paths as ``root + "/" + name`` therefore only reach
    files that sit directly inside the root.

    Args:
        path1: Root of the first directory tree.
        path2: Root of the second directory tree.

    Returns:
        Alphabetically sorted list of shared file names. Sorting keeps the
        index -> file mapping stable between runs; raw set order is not.
        A missing or empty root yields an empty list.
    """
    def names_under(root):
        # os.walk silently yields nothing for a non-existent root.
        return {name for _, _, files in os.walk(root) for name in files}

    return sorted(names_under(path1) & names_under(path2))

# File names present in both folders. CreateDataset.__getitem__ indexes into
# this module-level list, so it must exist before any dataset item is fetched.
common = findCommonDeep(path1, path2)

In [None]:
# Sanity check: show one matched file name (raises IndexError when no pairs were found).
common[0]

In [None]:
class CreateDataset(torch.utils.data.Dataset):
  """Paired (watermarked, clean) image dataset.

  Item ``i`` is built from the file named ``common[i]`` (module-level list),
  looked up once under ``img_path`` and once under ``clean_img_path``. Both
  images go through the same pipeline: resize to 512x512, convert to a float
  tensor and normalise each channel with mean 0.5 / std 0.5.

  Args:
    img_path: Directory holding the watermarked images.
    clean_img_path: Directory holding the matching clean images.
    len: Number of items the dataset reports; indices up to ``len - 1`` must
      be valid positions in ``common``.
  """
  def __init__(self, img_path, clean_img_path, len):
    self.len = len
    self.transforms = torchvision.transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((512,512)),
                torchvision.transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5),
                                      (0.5, 0.5, 0.5))
            ]
        )
    self.img_path = img_path
    self.clean_img_path = clean_img_path

  def _load(self, root, file_name):
    """Read ``root/file_name`` as a 3-channel image and apply the transforms."""
    # Force RGB: greyscale or RGBA files would otherwise decode to 1 or 4
    # channels and break the 3-channel Normalize step.
    raw = torchvision.io.read_image(
        os.path.join(root, file_name), mode=torchvision.io.ImageReadMode.RGB)
    return self.transforms(raw)

  def __getitem__(self, index):
    """Return ``(watermarked, clean)`` tensors for position ``index``."""
    file_name = common[index]
    watermarked_img = self._load(self.img_path, file_name)
    clean_img = self._load(self.clean_img_path, file_name)
    return watermarked_img, clean_img

  def __len__(self):
    """Return the item count supplied at construction time."""
    return self.len

In [None]:
class UBlock(nn.Module):
  """Image-to-image block with a shrinking path, a growing path and skips.

  * ``layer1`` lifts the input to 48 feature maps at full resolution.
  * ``layer2`` .. ``layer7`` are conv + ReLU + 3x3 max-pool (stride 1, no
    padding); each removes 2 pixels from the height and the width.
  * ``g1`` .. ``g6`` are conv + ReLU + 3x3 transposed conv (stride 1); each
    adds 2 pixels back.
  * After every ``g*`` the result is concatenated, along the channel axis,
    with the activation saved at the matching resolution (the raw input for
    the last one) and fused by ``b*``.
  * ``layer_out`` maps the fused features to ``out_channels``.

  Attribute names and the position of every layer inside each ``Sequential``
  are part of the checkpoint format (state-dict keys) and must not change.

  Args:
    in_channels: Channels of the input image.
    out_channels: Channels of the produced image.

  Input height and width must be at least 13 so that the six pooling stages
  never see a map smaller than their 3x3 window.
  """
  def __init__(self, in_channels=3, out_channels=3):
    super(UBlock, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.layer1 = nn.Sequential(
        nn.Conv2d(self.in_channels, 48, kernel_size = 3, stride = 1, padding=1),
        nn.ReLU())
    # Shrinking path: six identical stages (created in this order so seeded
    # initialisation stays reproducible).
    self.layer2 = self._down(48)
    self.layer3 = self._down(48)
    self.layer4 = self._down(48)
    self.layer5 = self._down(48)
    self.layer6 = self._down(48)
    self.layer7 = self._down(48)
    # Growing path: g* restores 2 pixels, b* fuses it with the saved skip.
    self.g1 = self._up(48)
    self.b1 = self._fuse(48 + 48, 96)
    self.g2 = self._up(96)
    self.b2 = self._fuse(96 + 48, 96)
    self.g3 = self._up(96)
    self.b3 = self._fuse(96 + 48, 96)
    self.g4 = self._up(96)
    self.b4 = self._fuse(96 + 48, 96)
    self.g5 = self._up(96)
    self.b5 = self._fuse(96 + 48, 96)
    self.g6 = self._up(96)
    # The last skip is the raw input, hence ``+ in_channels`` (99 by default).
    self.b6 = self._fuse(96 + self.in_channels, 64)
    self.layer_out = nn.Sequential(
        nn.Conv2d(64, 32, kernel_size = 3, stride = 1, padding = 1),
        nn.ReLU(),
        nn.Conv2d(32, self.out_channels, kernel_size = 3, stride = 1, padding = 1),
        nn.LeakyReLU()
    )

  @staticmethod
  def _down(channels):
    """Conv + ReLU + unpadded 3x3 max-pool: same channels, H and W minus 2."""
    return nn.Sequential(
        nn.Conv2d(channels, channels, kernel_size = 3, stride = 1, padding = 1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size = 3, stride = 1)
    )

  @staticmethod
  def _up(channels):
    """Conv + ReLU + 3x3 transposed conv: same channels, H and W plus 2."""
    return nn.Sequential(
        nn.Conv2d(channels, channels, kernel_size = 3, stride = 1, padding = 1),
        nn.ReLU(),
        nn.ConvTranspose2d(channels, channels, kernel_size = 3, stride = 1)
    )

  @staticmethod
  def _fuse(in_channels, out_channels):
    """Size-preserving conv + ReLU applied to a concatenated feature map."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
        nn.ReLU()
    )

  def forward(self, x):
    """Map an image tensor to an output of the same spatial size.

    Args:
      x: ``(N, in_channels, H, W)`` batch. A single ``(in_channels, H, W)``
        image also works on PyTorch builds whose conv/pool layers accept
        unbatched input.

    Returns:
      Tensor shaped like ``x`` but with ``out_channels`` channels.
    """
    skips = [x]  # the raw input is the skip for the final fusion
    x = self.layer1(x)
    for down in (self.layer2, self.layer3, self.layer4, self.layer5, self.layer6):
      x = down(x)
      skips.append(x)
    x = self.layer7(x)
    stages = (
        (self.g1, self.b1), (self.g2, self.b2), (self.g3, self.b3),
        (self.g4, self.b4), (self.g5, self.b5), (self.g6, self.b6),
    )
    for up, fuse in stages:
      # dim=-3 is the channel axis for both 3-D and 4-D tensors; dim 0 would
      # stack along the batch axis and leave the channel count too small.
      x = fuse(torch.cat((up(x), skips.pop()), dim=-3))
    x = self.layer_out(x)
    return x

In [None]:
class Model(nn.Module):
  """Two default-configured ``UBlock`` stages applied one after the other."""
  def __init__(self):
    super(Model, self).__init__()
    self.block1 = UBlock()
    self.block2 = UBlock()

  def forward(self, img):
    """Run ``img`` through ``block1`` and feed the result to ``block2``.

    No shape is printed here: a print inside ``forward`` would fire on every
    batch and flood the training output.
    """
    out = self.block1(img)
    out = self.block2(out)
    return out

In [None]:
# Folders handed to CreateDataset: watermarked inputs and their clean targets.
# NOTE(review): these repeat path2 / path1 from the file-matching cell and must stay in sync with them.
img_path = "/content/wm-nowm/train/watermark"
clean_img_path = "/content/wm-nowm/train/no-watermark"

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu' 

In [None]:
# Optimisation set-up: L1 reconstruction loss, Adam, and a cosine learning-rate
# schedule whose half-period is T_max scheduler steps.
T_max = 30
loss_function = nn.L1Loss()
model = Model().to(device)  # Module.to() returns the module itself
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4, betas=(0.5, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=T_max)

In [None]:
# Resume from the lowest-loss checkpoint available when the form was submitted.
# map_location remaps storages saved from a GPU so the file also loads on a
# CPU-only runtime instead of raising a CUDA deserialisation error.
model.load_state_dict(torch.load("/content/model (8).pth", map_location=device))

In [None]:
def train(watermarked_image, clean_image, loss_function, optimizer, net=None):
  """Run one optimisation step on a single batch.

  Args:
    watermarked_image: Network input batch.
    clean_image: Target batch the prediction is compared against.
    loss_function: Callable ``(prediction, target) -> scalar tensor``.
    optimizer: Optimiser holding the parameters of the network being trained.
    net: Network to train; defaults to the module-level ``model``.

  Returns:
    ``(prediction, loss)``. The prediction is detached from autograd so that
    callers can keep it without retaining graph state; the loss is a float.
  """
  if net is None:
    net = model
  optimizer.zero_grad()
  pred_img = net(watermarked_image) # forward pass
  loss = loss_function(pred_img, clean_image) # reconstruction error
  loss.backward() # back-propagate
  optimizer.step()
  return pred_img.detach(), loss.item()

In [None]:
# Training pairs: one item per matched file name, served in shuffled batches of
# 16 from page-locked host memory.
train_dataset = CreateDataset(img_path=img_path, clean_img_path=clean_img_path, len=len(common))
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, pin_memory=True)

In [None]:
# Evaluation pairs.
# NOTE(review): this reads the same folders and file list as the training set,
# so losses measured here are training losses -- point it at a held-out split.
test_dataset = CreateDataset(img_path, clean_img_path, len(common))
# No shuffling for evaluation: a fixed order keeps the returned predictions
# aligned with the dataset indices from run to run.
test_dataloader = DataLoader(dataset = test_dataset, batch_size=16, shuffle=False, pin_memory=True)

In [None]:
def fit(epochs, optimizer, loss_function, scheduler=None, detect_anomaly=False):
  """Train the module-level ``model`` on ``train_dataloader``.

  Args:
    epochs: Number of full passes over the training loader.
    optimizer: Optimiser forwarded to ``train`` for every batch.
    loss_function: Loss forwarded to ``train`` for every batch.
    scheduler: Optional learning-rate scheduler, stepped once per epoch.
    detect_anomaly: Enable autograd anomaly detection. Off by default because
      it is a debugging aid that slows every backward pass.

  Returns:
    List with the loss of every processed batch, in order.
  """
  losses = []
  model.train()
  for epoch in range(epochs):
      train_loss = float("nan")  # reported as-is when the loader yields no batch
      with torch.autograd.set_detect_anomaly(detect_anomaly):
        for watermarked_img, clean_img in tqdm(train_dataloader):
          watermarked_img = watermarked_img.to(device)
          clean_img = clean_img.to(device)
          _, train_loss = train(watermarked_img, clean_img, loss_function, optimizer)
          losses.append(train_loss)
      if scheduler is not None:
        scheduler.step()
      # The printed value is the loss of the epoch's final batch.
      print("Epoch [{}/{}], train_loss: {:.4f}".format(epoch+1, epochs, train_loss))
  return losses

In [None]:
def evaluate(epochs, optimizer, loss_function):
  """Measure the loss of the module-level ``model`` on ``test_dataloader``.

  Only forward passes are run: no gradients are computed and no weights are
  updated, so the model is unchanged by evaluation.

  Args:
    epochs: Number of passes over the evaluation loader.
    optimizer: Unused; kept so existing calls remain valid.
    loss_function: Callable ``(prediction, target) -> scalar tensor``.

  Returns:
    ``(losses, preds, actual)``: per-batch losses, predicted batches and
    target batches. Tensors are moved to the CPU so that accumulating them
    does not exhaust device memory.
  """
  losses = []
  preds = []
  actual = []
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for epoch in range(epochs):
        test_loss = float("nan")  # reported as-is when the loader yields no batch
        for watermarked_img, clean_img in tqdm(test_dataloader):
          watermarked_img = watermarked_img.to(device)
          clean_img = clean_img.to(device)
          pred_img = model(watermarked_img)
          test_loss = loss_function(pred_img, clean_img).item()
          losses.append(test_loss)
          preds.append(pred_img.cpu())
          actual.append(clean_img.cpu())
        # The printed value is the loss of the pass's final batch.
        print("Epoch [{}/{}], val_loss: {:.4f}".format(epoch+1, epochs, test_loss))
  finally:
    # Put the model back into whichever mode the caller had it in.
    model.train(was_training)
  return losses, preds, actual

In [None]:
# Keep the per-batch losses in a variable: a bare call would echo the whole
# list as the cell output and leave nothing to plot afterwards.
train_losses = fit(epochs=100, optimizer=optimizer, loss_function=loss_function)

In [None]:
# Bind the results to names: a bare call would echo every returned tensor as
# the cell output.
val_losses, val_preds, val_targets = evaluate(epochs=100, optimizer=optimizer, loss_function = loss_function)

In [None]:
# evaluation of test images

# Single-image inference. The image goes through the same preprocessing as the
# training data and is passed to the model as one unbatched (C, H, W) tensor.
# Decode as RGB so greyscale or RGBA files still match the 3-channel Normalize.
watermarked_img = torchvision.io.read_image(
    "/content/PS2_test_image_7.jpg", mode=torchvision.io.ImageReadMode.RGB)
watermarked_img = transforms.ToPILImage()(watermarked_img)
set_transforms = torchvision.transforms.Compose(
            [
                transforms.Resize((512,512)),
                torchvision.transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5),
                                      (0.5, 0.5, 0.5))
            ]
        )
watermarked_img = set_transforms(watermarked_img).to(device)
model.eval()
# Inference needs no gradients; skipping them saves memory and time.
with torch.no_grad():
    pred = model(watermarked_img)
show_images(pred.cpu())