# Watermark Detection


## 1. Setup

In [None]:
from tqdm import tqdm 

import os
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None

import torchvision
import numpy as np

%load_ext autoreload
%autoreload 2

# Let the process continue when a second copy of the OpenMP runtime gets
# loaded; without it the Jupyter kernel can die. NOTE(review): the runtime
# reads this when it initialises, and torch/numpy are already imported above,
# so this may need to move before those imports to take effect — confirm.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

In [None]:
# Choose the compute device: Apple MPS first, then TPU/XLA, then CUDA, then CPU.
# `xm` is None when torch_xla failed to import (see the import cell), so it
# must be checked before use; otherwise machines without MPS or torch_xla
# would crash here with AttributeError instead of falling back.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif xm is not None:
    device = xm.xla_device()
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

## 2. Visualize Data

In [None]:
from project.data.data_module import DataModule

# Hyperparameters handed to the DataModule here and to the models later on.
# `dev` carries the device selected in the previous cell.
hparams = dict(
    learning_rate=1e-5,
    batch_size=64,
    batch_norm=False,
    epochs=20,
    num_classes=2,
    num_workers=8,
    dev=device,
    load_method='memory',
)

data_module = DataModule(hparams)

In [None]:
# Grab a single validation batch for visualisation. `next()` states the intent
# directly; an empty loader now fails here with a clear message instead of
# leaving `images`/`labels` undefined for the plotting cell.
dataiter = iter(data_module.get_valid_dataloader())
try:
    images, labels = next(dataiter)
except StopIteration:
    raise RuntimeError("Validation dataloader is empty; nothing to visualize.") from None

In [None]:
# Mapping from integer class index (the values in `labels`) to a human-readable
# class name; used to title the sample images.
idx_to_class = data_module.get_idx_to_class_dict()

In [None]:
from project.utils.images import imshow

# Show up to grid_rows * grid_cols samples from the batch, each titled with its
# class name. Slicing the batch (instead of indexing range(4)) avoids an
# IndexError when the batch holds fewer images than grid cells.
grid_rows, grid_cols = 2, 2
n_show = grid_rows * grid_cols

fig = plt.figure(figsize=(20, 20))

for idx, (image, label) in enumerate(zip(images[:n_show], labels[:n_show])):
    ax = fig.add_subplot(grid_rows, grid_cols, idx + 1, xticks=[], yticks=[])
    imshow(image, ax)
    ax.set_title(idx_to_class[int(label)], fontdict={'fontsize': 20})

## 3. Baseline Model

In [None]:
from project.networks.naive import NaiveModel
from project.utils.models import init_weights, number_of_parameters

# Build the baseline network from the shared hyperparameters, re-initialise
# every submodule with `init_weights`, then report the model size.
naive_model = NaiveModel(hparams=hparams)
naive_model.apply(init_weights)

param_count = number_of_parameters(naive_model)
print('# Parameters: ', param_count)