In [1]:
import sys
# Put the directory three levels up at the front of the import search path,
# but only once, so re-running this cell does not add duplicates.
# NOTE(review): presumably this is the project root that makes the
# `data_retrieval` / `degradations` imports below resolve — confirm.
path = "../../.."
if sys.path.count(path) == 0:
    sys.path.insert(0, path)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from data_retrieval import lipade_groundtruth
from data_retrieval.tools.data_loader import getDataLoader
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights, resnet50, ResNet50_Weights
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR

# Use the first CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Hyper-parameters and loader settings for this notebook.
optimizerFunc = optim.Adam
temperature = 0.5
learningRate = 1e-3
batch_size = 512
workers = 2

# Corpus identifier and the results directory derived from it.
corpus = "lipade_groundtruth"
resultsPath = f"../results/distance/{corpus}/"


## Dataset

In [4]:
# getDataset returns three values; the first is used below as the image
# sources and the third is kept as `y`.
# NOTE(review): `y` presumably holds the ground-truth labels — confirm.
x, _, y = lipade_groundtruth.getDataset(mode='unique', uniform=True)

# Decode every image to RGB. The context manager closes each source file as
# soon as convert() has produced its independent in-memory copy, so file
# handles do not accumulate while the whole corpus is read.
images = []
for image_path in x:
    with Image.open(image_path) as source_image:
        images.append(source_image.convert('RGB'))

In [5]:
# Convert the three Python sequences to NumPy arrays in a single pass; the
# decoded images are stacked into one array.
x, y, images = (np.array(sequence) for sequence in (x, y, images))

# Cell output: dimensions of the stacked image array.
images.shape

(5855, 256, 256, 3)

In [6]:
# Build the training loader over the decoded images. The `workers` constant
# from the configuration cell replaces the duplicated literal 2, so the
# setting is tuned in one place. The other arguments are forwarded unchanged;
# their meaning is defined by getDataLoader (not shown in this notebook).
trainLoader = getDataLoader(images, None, None, False, batch_size, True, num_workers=workers)

## Transformations

In [9]:
from degradations.methods.halftoning.floyd_steinberg import transforms_floyd_steinberg_halftoning
from degradations.methods.halftoning.atkinson import transforms_atkinson_dithering
from degradations.methods.halftoning.bayers_threshold import transforms_bayer_halftoning
from degradations.methods.halftoning.dot_traditional import transforms_dot_halftoning  # Import your halftoning methods
from degradations.methods.noise.gaussian_noise import transforms_add_gaussian_noise
from degradations.methods.noise.salt_and_pepper import transforms_add_salt_and_pepper_noise
from degradations.methods.noise.dirty_rollers import transforms_dirty_rollers
#from degradations.methods.noise.film_grain import transforms_apply_film_grain # Import your noise methods
from degradations.methods.paper.ink_bleed import transforms_ink_bleed  
from degradations.methods.paper.crumpled_paper import transforms_crumpled_paper
from degradations.methods.paper.folded_paper import transforms_folded_paper
from degradations.methods.paper.bleedthrough import transforms_bleedthrough
from degradations.methods.paper.scribbles import transforms_scribbles
from degradations.methods.paper.stains import transforms_stains # Import your paper feel methods
from degradations.methods.human_corrections.erased_element import transforms_erased_element # Import your human correction methods
from degradations.methods.layout.picture_overlay import transforms_picture_overlay
from degradations.methods.layout.text_overlay import transforms_text_overlay # Import your layout methods


class transforms_SepiaFilter(nn.Module):
    """Apply a sepia tone to channel-first, 3-channel image tensors.

    Every output channel is a fixed linear combination of the three input
    channels (one row of ``SEPIA_WEIGHTS`` per output channel), and the result
    is clamped to ``[0, 1]``.
    """

    # One row per output channel; columns weight the three input channels.
    # NOTE(review): assumes RGB channel order — confirm against the loader.
    SEPIA_WEIGHTS = (
        (0.393, 0.769, 0.189),
        (0.349, 0.686, 0.168),
        (0.272, 0.534, 0.131),
    )

    def __init__(self):
        super().__init__()

    def forward(self, batch):
        """Return the sepia-toned version of ``batch``.

        Args:
            batch: floating-point tensor shaped ``(..., 3, H, W)``; any number
                of leading dimensions (including none) is accepted.

        Returns:
            Tensor with the same shape, dtype and device as ``batch``, with
            values clamped to ``[0, 1]``.
        """
        # Build the filter in the input's own dtype: einsum rejects mixed
        # dtypes, so a float32 filter would fail on float64/float16 batches.
        if batch.is_floating_point():
            weight_dtype = batch.dtype
        else:
            weight_dtype = torch.get_default_dtype()
        sepia_filter = torch.tensor(
            self.SEPIA_WEIGHTS, dtype=weight_dtype, device=batch.device
        )
        # Mix the channel axis (third from the end); spatial axes are untouched.
        toned = torch.einsum('...jkl,mj->...mkl', batch, sepia_filter)
        return toned.clamp(0, 1)


class transforms_Rotate(nn.Module):
    """Exchange axes 2 and 3 of the input tensor.

    NOTE(review): despite the name, this is a transpose of those two axes
    (height and width for an ``(N, C, H, W)`` batch), not a rotation by an
    angle.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, batch):
        """Return ``batch`` with dimensions 2 and 3 swapped."""
        return batch.transpose(2, 3)

In [10]:
# Every optional stage is wrapped in RandomApply so it fires independently
# with its own probability; RandomChoice picks one method within a family.

# Square crop covering 2/3 to all of the image area, resized to
# images.shape[2] pixels per side.
crop_stage = transforms.RandomApply(
    [transforms.RandomResizedCrop(size=images.shape[2], scale=(2/3, 1), ratio=(1, 1))],
    p=1/3,
)

# Sepia toning.
sepia_stage = transforms.RandomApply([transforms_SepiaFilter()], p=1/3)

# Halftoning: one of the two dithering methods.
halftone_stage = transforms.RandomApply(
    [
        transforms.RandomChoice([
            transforms_floyd_steinberg_halftoning(128),
            transforms_atkinson_dithering(128),
        ])
    ],
    p=0.05,
)

# Layout: a picture or text overlay.
layout_stage = transforms.RandomApply(
    [
        transforms.RandomChoice([
            transforms_picture_overlay(),
            transforms_text_overlay(),
        ])
    ],
    p=0.2,
)

# Human correction: erased element.
erased_stage = transforms.RandomApply([transforms_erased_element()], p=0.1)

# Noise: gaussian, salt-and-pepper or dirty rollers.
noise_stage = transforms.RandomApply(
    [
        transforms.RandomChoice([
            transforms_add_gaussian_noise(),
            transforms_add_salt_and_pepper_noise(),
            transforms_dirty_rollers((8,10)),
        ])
    ],
    p=0.1,
)

# Stains: stains, ink bleed or bleed-through (transforms_scribbles is
# deliberately left out of the choice).
stain_stage = transforms.RandomApply(
    [
        transforms.RandomChoice([
            transforms_stains(),
            transforms_ink_bleed(),
            transforms_bleedthrough(),
        ])
    ],
    p=0.3,
)

# Paper texture: crumpled or folded paper.
texture_stage = transforms.RandomApply(
    [
        transforms.RandomChoice([
            transforms_crumpled_paper(),
            transforms_folded_paper(0.4),
        ])
    ],
    p=0.2,
)

# Full degradation pipeline. The flips are random; colour jitter and the
# Gaussian blur are part of every call.
transform_degrad = transforms.Compose([
    crop_stage,
    sepia_stage,
    halftone_stage,
    layout_stage,
    erased_stage,
    noise_stage,
    stain_stage,
    texture_stage,
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
    transforms.GaussianBlur(kernel_size=9),
])


## Modèle

In [13]:
class DINO_Representation(nn.Module):
    """Student/teacher representation model with a momentum (EMA) teacher.

    The student is ``encoder`` followed by a two-layer projection head. The
    teacher is a frozen copy of the student encoder whose weights follow the
    student through an exponential moving average; it reuses the student's
    projection head.

    Args:
        encoder: backbone mapping an input batch to ``(batch, in_dim)``
            features.
        in_dim: feature size produced by ``encoder``. It must match the
            backbone (e.g. 512 for a torchvision ResNet-18 whose ``fc`` was
            replaced by ``nn.Identity``, 2048 for ResNet-50).
        out_dim: size of the projected embedding.
        m: EMA coefficient; each update sets
            ``teacher = m * teacher + (1 - m) * student``.
    """

    def __init__(self, encoder, in_dim=2048, out_dim=128, m=0.99):
        super().__init__()
        # Local import: keeps this class self-contained without touching the
        # notebook's import cell.
        import copy

        # Student encoder
        self.encoder = encoder

        # Projection head, applied to both student and teacher features
        self.projection_head = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Linear(512, out_dim)
        )

        # Teacher: an exact copy of the student, so both always share the same
        # architecture and the EMA update pairs matching parameters whatever
        # backbone was passed in.
        self.momentum_encoder = copy.deepcopy(encoder)

        # The teacher is never trained by back-propagation.
        self.momentum_encoder.eval()
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False

        # Momentum update coefficient
        self.m = m

    def train(self, mode=True):
        """Set the training mode of the student; the teacher stays in eval mode.

        ``nn.Module.train`` recurses into every sub-module, which would undo
        the ``eval()`` call made on the teacher in ``__init__``.
        """
        super().train(mode)
        self.momentum_encoder.eval()
        return self

    @torch.no_grad()
    def _update_teacher(self):
        """Move the teacher weights one EMA step towards the student weights."""
        student_params = self.encoder.parameters()
        teacher_params = self.momentum_encoder.parameters()
        for param_q, param_k in zip(student_params, teacher_params):
            param_k.mul_(self.m).add_(param_q.detach(), alpha=1.0 - self.m)

    def forward(self, x, x_2):
        """Project both views.

        Args:
            x: batch given to the teacher.
            x_2: batch given to the student.

        Returns:
            ``(student_z, teacher_z)``, each of shape ``(batch, out_dim)``.
            ``teacher_z`` carries no gradient.
        """
        student_z = self.projection_head(self.encoder(x_2))

        # Only track the student while training; inference must not alter the
        # teacher weights.
        if self.training:
            self._update_teacher()

        # The teacher output is a fixed target, so no autograd graph is built.
        with torch.no_grad():
            teacher_z = self.projection_head(self.momentum_encoder(x))

        return student_z, teacher_z


In [14]:
# Student backbone: ResNet-18 initialised with the default pretrained weights.
representationEncoder = resnet18(weights=ResNet18_Weights.DEFAULT)
# Read the feature width before removing the classifier, so the projection
# head is sized to the backbone (512 for ResNet-18). The class default of
# in_dim=2048 would make the first forward pass fail with a shape mismatch.
feature_dim = representationEncoder.fc.in_features
representationEncoder.fc = nn.Identity()

model = DINO_Representation(representationEncoder, in_dim=feature_dim).to(device)

## Loss

In [15]:
class DINOLoss(nn.Module):
    """Negative mean cosine similarity between student and teacher projections.

    NOTE(review): ``temperature`` is stored on the module but never read by
    ``forward``; no softmax or cross-entropy is computed here.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_z, teacher_z):
        """Return the scalar loss for one batch.

        Both inputs are L2-normalised along the last dimension. The teacher
        side is detached so that gradients reach only the student.
        """
        target = F.normalize(teacher_z, p=2, dim=-1).detach()
        prediction = F.normalize(student_z, p=2, dim=-1)

        # Row-wise dot product of unit vectors, i.e. the cosine similarity.
        cosine = (prediction * target).sum(dim=-1)
        return cosine.mean().neg()


## Training

In [16]:
criterion = DINOLoss(temperature=0.1)
# Give Adam only the parameters that can receive gradients. The teacher
# encoder is frozen (requires_grad=False) and is changed by the EMA rule
# inside the model, not by the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)

In [None]:
epochs = 100

# Mean training loss of every finished epoch, stored as plain floats.
losses_all = []

# The context manager flushes and closes the TensorBoard event file even when
# training is interrupted.
with SummaryWriter(log_dir="logs/DINO") as writer:
    for epoch in range(epochs):
        model.train()
        # model.train() switches every sub-module to train mode; put the frozen
        # teacher back in eval mode so forward passes do not update its
        # normalisation statistics.
        model.momentum_encoder.eval()
        losses = []

        # The loop variable is `batch`, not `x`, so the dataset array `x`
        # built in the loading cells is not overwritten.
        for batch in tqdm(trainLoader, desc="Epoch " + str(epoch)):
            batch = batch.to(device)
            batch_degraded = transform_degrad(batch)

            optimizer.zero_grad()

            # Teacher receives the clean batch, student the degraded view.
            student_z, teacher_z = model(batch, batch_degraded)

            # Compute the loss
            loss = criterion(student_z, teacher_z)

            # Backward pass
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

        epoch_loss = float(np.mean(losses))
        writer.add_scalar("Loss/train", epoch_loss, epoch)
        losses_all.append(epoch_loss)
        # Overwrite the checkpoint after each epoch so an interrupted run keeps
        # the most recent weights.
        torch.save(model.state_dict(), "model_dino.pth")

fig, ax = plt.subplots()
ax.plot(losses_all)
ax.set(xlabel="Epoch", ylabel="Mean training loss", title="DINO training loss")
plt.show()