In [None]:
import torch
import torchvision
import torch.nn.functional as F
from torchvision import transforms
import random
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR as lr_scheduler
import matplotlib.pyplot as plt

In [None]:
import os

# Root directory for the CelebA download/cache. Set the CELEBA_ROOT environment
# variable to point elsewhere instead of editing a machine-specific absolute path.
celebA_path = os.environ.get(
    "CELEBA_ROOT", "/home/tomi64/projects/melytanulas_project/celebA"
)

# Only the images are used downstream; the dataset wrapper discards the target.
celebA = torchvision.datasets.CelebA(root=celebA_path, download=True)

In [None]:
# Side length (pixels) every image is resized to. The U-Net max-pools four
# times and upsamples back with stride-2 transposed convs, so this must be
# divisible by 2**4 = 16 for the skip-connection concatenations to line up.
IMAGE_SIZE = 112
BATCH_SIZE = 24

# Seed the RNGs this notebook draws from (timestep sampling uses `random`,
# noise / shuffling / weight init use torch) so runs are reproducible.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

assert IMAGE_SIZE % 16 == 0, "IMAGE_SIZE must be divisible by 16 (four pooling levels)"

In [None]:
class DataWithNoise(torch.utils.data.Dataset):
    """Wrap an image dataset and yield ``(noisy_image, noise, t)`` training triples.

    For every item a timestep ``t`` is drawn uniformly from ``[0, timesteps - 1]``
    and the transformed image ``x0`` is corrupted with Gaussian noise using the
    variance-preserving forward process

        x_t = sqrt(alpha_t) * x0 + sqrt(1 - alpha_t) * eps,   eps ~ N(0, I)

    where ``alpha_t = self.alpha_schedule[t]`` is the remaining signal level
    (usually written alpha-bar in the DDPM literature).

    Args:
        dataset: indexable dataset whose items are ``(image, target)`` pairs;
            the target is ignored.
        transforms: callable mapping a raw image to a tensor.
        timesteps: number of diffusion steps in the schedule.
    """

    def __init__(self, dataset, transforms, timesteps):
        self.dataset = dataset
        self.transforms = transforms
        self.timesteps = timesteps
        # Linear beta schedule. Not read by __getitem__; kept because other code
        # may use it (an alternative signal schedule is cumprod(1 - beta)).
        self.beta_schedule = torch.linspace(1e-4, 2e-2, timesteps)
        # Cosine signal schedule decreasing from 1 to ~0. The end point 3.1415/2 is
        # deliberately just below pi/2: cos of float32(pi/2) is a tiny negative
        # number, and sqrt of it in __getitem__ would be NaN.
        self.alpha_schedule = torch.cos(torch.linspace(0, 3.1415 / 2, timesteps))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, _ = self.dataset[idx]
        image = self.transforms(image)

        noise = torch.randn_like(image)
        t = random.randint(0, self.timesteps - 1)
        alpha = self.alpha_schedule[t]
        # Both coefficients are square roots so that alpha + (1 - alpha) = 1, i.e.
        # x_t keeps unit variance when x0 and the noise do.
        noisy_image = torch.sqrt(alpha) * image + torch.sqrt(1 - alpha) * noise
        # t is returned as a float tensor of shape (1,); batching gives (batch, 1).
        return noisy_image, noise, torch.Tensor([t])


In [None]:
def _to_signed_range(x):
    """Rescale a tensor with values in [0, 1] to [-1, 1] (named so it is picklable)."""
    return x * 2 - 1


# Length of the diffusion noise schedule used by the training dataset.
NUM_TIMESTEPS = 500

# Force every image to IMAGE_SIZE x IMAGE_SIZE (aspect ratio is not preserved),
# convert to a float tensor (ToTensor yields [0, 1]) and shift into [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Lambda(_to_signed_range),
    ]
)
celeba_with_noise = DataWithNoise(celebA, transform, NUM_TIMESTEPS)

In [None]:
class Level(nn.Module):
    """One U-Net resolution level: two (BatchNorm -> 3x3 conv -> ReLU) stages.

    The spatial size is preserved (3x3 kernels, stride 1, padding 1); channels
    go from ``inch`` to ``och``.

    Args:
        inch: number of input channels.
        och: number of output channels.
        dim: width of the sinusoidal timestep embedding.
        time_conditioning: if True, a learned projection of the timestep
            embedding is added to the feature map after the first stage. If
            False (the default), ``t`` is accepted but ignored.
    """

    def __init__(self, inch, och, dim, time_conditioning=False):
        super().__init__()
        self.time_conditioning = time_conditioning
        self.bn1 = nn.BatchNorm2d(inch)
        self.conv1 = nn.Conv2d(inch, och, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(och)
        self.conv2 = nn.Conv2d(och, och, 3, 1, 1)
        self.act = nn.ReLU()
        # Always constructed so the parameter set (and saved state_dicts) does not
        # depend on time_conditioning.
        self.time_embedding = nn.Sequential(
            Sinusoidal_embedding(dim),
            nn.Linear(dim, och),
            nn.ReLU()
        )

    # No __call__ override: nn.Module.__call__ already dispatches to forward()
    # and additionally runs registered hooks, which a custom __call__ bypasses.

    def forward(self, x, t):
        """Apply the level to feature map ``x`` (optionally conditioned on ``t``)."""
        x = self.act(self.conv1(self.bn1(x)))

        if self.time_conditioning:
            # Broadcast the per-sample embedding over height and width. Assumes t
            # has shape (batch, 1), as batches from DataWithNoise do — TODO confirm
            # for other callers.
            embedding = self.time_embedding(t)[..., None, None]
            x = x + embedding

        x = self.act(self.conv2(self.bn2(x)))
        return x

class Upconv(nn.Module):
    """Upsampling step: BatchNorm -> 2x2 stride-2 transposed conv -> ReLU.

    Doubles height and width and maps ``inch`` channels to ``och`` channels.
    """

    def __init__(self, inch, och):
        super().__init__()
        self.bn = nn.BatchNorm2d(inch)
        self.conv = nn.ConvTranspose2d(inch, och, 2, 2)
        self.act = nn.ReLU()

    # No __call__ override: nn.Module.__call__ dispatches to forward() and runs hooks.

    def forward(self, x):
        return self.act(self.conv(self.bn(x)))
    
class Sinusoidal_embedding(nn.Module):
    """Transformer-style sinusoidal embedding of a (possibly fractional) timestep.

    With ``half = dim // 2`` and frequencies ``w_k = 10000 ** (-k / (half - 1))``
    (geometrically spaced from 1 down to 1e-4 when ``half >= 2``), the output is
    ``[sin(t*w_0), ..., sin(t*w_{half-1}), cos(t*w_0), ..., cos(t*w_{half-1})]``
    along the last axis. For an odd ``dim`` one zero column is appended so the
    output width is always ``dim``.

    Args:
        dim: output embedding width.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    # No __call__ override: nn.Module.__call__ dispatches to forward() and runs hooks.

    def forward(self, t):
        """Embed timesteps ``t``.

        Args:
            t: tensor whose last axis has size 1 (e.g. shape ``(batch, 1)``) so it
               broadcasts against the frequency vector.

        Returns:
            Tensor of shape ``t.shape[:-1] + (dim,)``.
        """
        half_dim = self.dim // 2
        exponents = torch.linspace(0, 1, half_dim, device=t.device)
        # Frequencies must decrease (10000 ** -e). Multiplying t by 10000 ** e instead
        # would give frequencies up to 1e4 rad/step, which alias to noise for
        # integer timesteps.
        freqs = 10000 ** (-exponents)
        args = t * freqs
        embedding = torch.cat((torch.sin(args), torch.cos(args)), dim=-1)
        if self.dim % 2:
            # Keep the width equal to dim so a following Linear(dim, ...) fits.
            embedding = torch.nn.functional.pad(embedding, (0, 1))
        return embedding


class Unet(nn.Module):
    """Four-level U-Net that predicts the noise added to a 3-channel image.

    Encoder: five ``Level`` blocks (3 -> 64 -> 128 -> 256 -> 512 -> 1024 channels)
    with 2x2 max-pooling between them. Decoder: four ``Upconv`` + ``Level``
    stages, each concatenating the matching encoder activation (skip
    connection) before its ``Level``. A final BatchNorm + 1x1 conv maps 64
    channels back to 3; there is no output activation, so predictions are
    unbounded.

    Input height and width must be divisible by 16 so that pooled and
    upsampled feature maps line up for concatenation.

    Args:
        dim: width of the timestep embedding passed to every ``Level``.
    """

    def __init__(self, dim):
        super().__init__()
        self.downlevel1 = Level(3, 64, dim)
        self.downlevel2 = Level(64, 128, dim)
        self.downlevel3 = Level(128, 256, dim)
        self.downlevel4 = Level(256, 512, dim)
        self.downlevel5 = Level(512, 1024, dim)
        self.uplevel1 = Level(1024, 512, dim)
        self.uplevel2 = Level(512, 256, dim)
        self.uplevel3 = Level(256, 128, dim)
        self.uplevel4 = Level(128, 64, dim)
        self.pool = nn.MaxPool2d(2, 2)
        self.upconv1 = Upconv(1024, 512)
        self.upconv2 = Upconv(512, 256)
        self.upconv3 = Upconv(256, 128)
        self.upconv4 = Upconv(128, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.lastconv = nn.Conv2d(64, 3, 1)
        # bn2 and act are not used by forward(); they are kept so parameter names
        # (and therefore saved state_dicts) stay unchanged.
        self.bn2 = nn.BatchNorm2d(3)
        self.act = nn.Sigmoid()

    # No __call__ override: nn.Module.__call__ dispatches to forward() and runs hooks.

    def forward(self, x, t):
        """Predict the noise in images ``x`` at timesteps ``t``.

        Args:
            x: image batch of shape (batch, 3, H, W), H and W divisible by 16.
            t: timesteps, forwarded unchanged to every ``Level``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Encoder: keep each level's output for the matching decoder stage.
        x = self.downlevel1(x, t)
        residual1 = x
        x = self.pool(x)

        x = self.downlevel2(x, t)
        residual2 = x
        x = self.pool(x)

        x = self.downlevel3(x, t)
        residual3 = x
        x = self.pool(x)

        x = self.downlevel4(x, t)
        residual4 = x
        x = self.pool(x)

        # Bottleneck.
        x = self.downlevel5(x, t)

        # Decoder: upsample, concatenate the skip connection on channels, refine.
        x = self.upconv1(x)
        x = torch.cat([residual4, x], 1)
        x = self.uplevel1(x, t)

        x = self.upconv2(x)
        x = torch.cat([residual3, x], 1)
        x = self.uplevel2(x, t)

        x = self.upconv3(x)
        x = torch.cat([residual2, x], 1)
        x = self.uplevel3(x, t)

        x = self.upconv4(x)
        x = torch.cat([residual1, x], 1)
        x = self.uplevel4(x, t)

        x = self.bn1(x)
        x = self.lastconv(x)
        return x

In [None]:
def train_loop(num_epoch, model, optimizer, scheduler, train_loader, criterion):
    """Train ``model`` to predict the noise in ``(noisy_image, noise, t)`` batches.

    Args:
        num_epoch: number of passes over ``train_loader``.
        model: module called as ``model(noisy_image, t)``. Batches are moved to
            the device of its parameters.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once at the end of every epoch.
        train_loader: iterable yielding ``(noisy_image, noise, t)`` tensors.
        criterion: ``criterion(target, prediction)`` returning a scalar loss.

    The progress bar shows the summed loss divided by the number of samples
    seen so far in the current epoch.
    """
    device = next(model.parameters()).device
    model.train()
    for epoch in range(num_epoch):
        progress = tqdm(train_loader, leave=True, position=0)
        sum_loss = 0.0
        num_samples = 0
        for noisy_image, noise, t in progress:
            optimizer.zero_grad()

            noisy_image = noisy_image.to(device)
            noise = noise.to(device)
            t = t.to(device)

            output = model(noisy_image, t)
            loss = criterion(noise, output)
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()
            # Use the real batch size: the final batch may be smaller than the
            # loader's nominal batch size.
            num_samples += noisy_image.size(0)
            # set_postfix refreshes the bar itself.
            progress.set_postfix(epoch=epoch, loss=sum_loss / num_samples)

        # Stepped once per epoch, so an epoch-based scheduler (e.g. cosine
        # annealing with T_max = num_epoch) completes its schedule exactly once.
        scheduler.step()

In [None]:
lr = 1e-4
lr_ratio = 1e-3  # minimum LR of the cosine schedule, as a fraction of lr
num_epoch = 3
model = Unet(32).cuda()


def criterion(x, y):
    """Sum of squared differences between ``x`` and ``y`` over every element."""
    return torch.sum((x - y) ** 2)


optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)
# T_max is counted in scheduler.step() calls; num_epoch assumes one step per epoch.
scheduler = lr_scheduler(optimizer, T_max=num_epoch, eta_min=lr * lr_ratio)
train_loader = DataLoader(celeba_with_noise, batch_size=BATCH_SIZE, shuffle=True)

# Run the training. Without this call, the following cell would save the
# freshly initialised (untrained) weights on a Restart & Run All.
train_loop(num_epoch, model, optimizer, scheduler, train_loader, criterion)


In [None]:
# Persist only the learned weights (the state_dict), not the pickled module object.
checkpoint_path = './model_celeb_A.pt'
weights = model.state_dict()
torch.save(weights, checkpoint_path)