In [1]:
from dataset import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))  # 0 corresponds to the first GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Working on device: ", device)

TRAIN_SIZE = 100
TEST_SIZE = 10

True
NVIDIA GeForce RTX 3050 Ti Laptop GPU
Working on device:  cuda:0


In [3]:
train_dataset, test_dataset = prepare_dataset(TRAIN_SIZE,TEST_SIZE)
train_loader, test_loader = prepare_dataloader(train_dataset, test_dataset)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
  return torch.tensor(gray_img)


'All tests passed'

Dataset loaded successfully
Data loader prepared successfully


In [4]:
'''
Encoder is a pretrained VGG up to relu4_1 as in the original paper (see 6.1 paper)
'''
class VGG_Encoder(torch.nn.Module):
    def __init__(self):
        super(VGG_Encoder, self).__init__()
        pretrained = torchvision.models.vgg19(pretrained=True)
        
        f = torch.nn.Sequential(*list(pretrained.features.children())[:21]).eval()

        # Splitting the network so we can get output of different layers
        # TODO: ADD REFLECTION PADDING LAYERS
        self.relu1_1 = torch.nn.Sequential(*f[:2],)
        self.relu2_1 = torch.nn.Sequential(*f[2:5], *f[5:7])
        self.relu3_1 = torch.nn.Sequential(*f[7:10],*f[10:12])
        self.relu4_1 = torch.nn.Sequential(*f[12:14],
                                          *f[14:16],
                                          *f[16:19],
                                           *f[19:21])
        
        for param in self.relu1_1.parameters():
            param.requires_grad = False
        for param in self.relu2_1.parameters():
            param.requires_grad = False
        for param in self.relu3_1.parameters():
            param.requires_grad = False
        for param in self.relu4_1.parameters():
            param.requires_grad = False

    def forward(self, x):
        out_1 = self.relu1_1(x)
        out_2 = self.relu2_1(out_1)
        out_3 = self.relu3_1(out_2)
        result = self.relu4_1(out_3)
        return out_1, out_2, out_3, result

def mean_and_std(x):
    x = x.view(x.shape[0], x.shape[1], -1)
    mean = x.mean(dim=2) + 0.00005
    std = x.var(dim=2).sqrt()
    return mean.view(mean.shape[0], mean.shape[1], 1, 1), std.view(std.shape[0], std.shape[1], 1, 1)

In [None]:
encoder = VGG_Encoder()
print(encoder)
out_1,out_2,out_3,out_4= encoder(torch.rand(4,3,256,256))
print(out_1.shape,out_2.shape,out_3.shape,out_4.shape)

In [None]:
'''
https://medium.com/analytics-vidhya/unet-implementation-in-pytorch-idiot-developer-da40d955f201
'''

class conv_block(torch.nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(out_c)        
        self.conv2 = torch.nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(out_c)         
        self.relu = torch.nn.ReLU()     
        
    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)        

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x
    

class decoder_block(torch.nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(out_c+out_c, out_c)     
        
    def forward(self, inputs, skip):
        display("shape of inputs",inputs.shape,"shape of skip",skip.shape)  
        x = self.up(inputs)
        # x = torch.cat([x, skip], axis=1)
        display("Concatenation successful",x.shape)
        x = self.conv(x)
        return x
    

class Unet(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder = VGG_Encoder()


        ## ----------------- v1 ----------------

        # ## Bottleneck
        # self.bottleneck = conv_block(512, 1024)
        
        # """ Decoder """
        # self.d1 = decoder_block(1024, 512)
        # self.d2 = decoder_block(512, 256)
        # self.d3 = decoder_block(256, 128)
        # self.d4 = decoder_block(128, 64)

        ## ----------------- v2 ----------------
        # """ Decoder """
        self.d0 = decoder_block(512, 512)
        self.d1 = decoder_block(512, 256)
        self.d2 = decoder_block(256, 128)
        self.d3 = decoder_block(128, 64)
        self.d4 = decoder_block(64, 32)


        ## output should be 3 channels image
        self.out = torch.nn.Conv2d(64, 3, kernel_size=1, padding=0)

    def forward(self,x):
        ## Encoder
        out1, out2, out3, out4 = self.encoder(x)

        # x = self.d0(x, out4)
        # x = self.d1(x, out3)
        # x = self.d2(x, out2)
        # x = self.d3(x, out1)
        # x = self.d4(x, x)

        x = self.d0(out4, out4)
        x = self.d1(x, out3)
        x = self.d2(x, out2)
        x = self.d3(x, out1)
        x = self.d4(x, out1)

        return self.out(x)
    
    # def forward(self, x):
    #     ## Encoder
    #     out1, out2, out3, out4 = self.encoder(x)

    #     ## bottleneck
    #     out = self.bottleneck(out4)

    #     ## decoder
    #     d1 = self.d1(out, out4)
    #     d2 = self.d2(d1, out3)
    #     d3 = self.d3(d2, out2)
    #     d4 = self.d4(d3, out1)
                
        
    #     out = self.out(d4)
    #     return out

In [None]:
a = torch.rand(4,3,256,256)
b = torch.rand(4,5,256,256)
con = torch.cat([a,b], axis=1)
print(con.shape)

In [None]:
model = Unet().to(device)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

NUM_EPOCHS = 2

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0

    for batch in tqdm(train_loader):
         
        grayscale_X = batch['grayscale_image']
        X = grayscale_X.repeat(1,3,1,1).to(device)
        y = batch['image'].to(device)

        out = model(X)
        loss = criterion(out, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)
        
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1} loss: {epoch_loss:.4f}")

-------------------------------------------------------------

In [5]:
''' 
decoder is just the second part of an Unet
implement skip connections (feed concat to the upsample layer)
'''
class Decoder(torch.nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        
        self.decode = torch.nn.Sequential(
            torch.nn.ReflectionPad2d((1, 1, 1, 1)),
            torch.nn.Conv2d(512, 256, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2, mode='nearest'),
            torch.nn.ReflectionPad2d((1, 1, 1, 1)),
            torch.nn.Conv2d(256, 256, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.ReflectionPad2d((1, 1, 1, 1)),
            torch.nn.Conv2d(256, 256, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.ReflectionPad2d((1, 1, 1, 1)),
            torch.nn.Conv2d(256, 256, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.ReflectionPad2d((1, 1, 1, 1)),
            torch.nn.Conv2d(256, 128, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2, mode='nearest'),
            torch.nn.ReflectionPad2d((1, 1, 1, 1)),
            torch.nn.Conv2d(128, 128, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.ReflectionPad2d((1, 1, 1, 1)),
            torch.nn.Conv2d(128, 64, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2, mode='nearest'),
            torch.nn.ReflectionPad2d((1, 1, 1, 1)),
            torch.nn.Conv2d(64, 64, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.ReflectionPad2d((1, 1, 1, 1)),
            torch.nn.Conv2d(64, 3, (3, 3)),
        )
        
    def forward(self, x):
        return self.decode(x)
    
class Unet(torch.nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        
        self.encoder = VGG_Encoder()
        self.decoder = Decoder()
        
    def forward(self, x):
        out_1, out_2, out_3, out_4 = self.encoder(x)
        return self.decoder(out_4)

In [6]:
model = Unet().to(device)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

NUM_EPOCHS = 20

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0

    for batch in tqdm(train_loader):
         
        grayscale_X = batch['grayscale_image']
        X = grayscale_X.repeat(1,3,1,1).to(device)
        y = batch['image'].to(device)

        out = model(X)
        loss = criterion(out, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * X.size(0)
        
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1} loss: {epoch_loss:.4f}")

100%|██████████| 25/25 [00:05<00:00,  4.44it/s]


Epoch 1 loss: 0.6611


100%|██████████| 25/25 [00:04<00:00,  6.22it/s]


Epoch 2 loss: 0.1079


100%|██████████| 25/25 [00:04<00:00,  6.01it/s]


Epoch 3 loss: 0.1070


100%|██████████| 25/25 [00:04<00:00,  5.98it/s]


Epoch 4 loss: 0.0978


100%|██████████| 25/25 [00:04<00:00,  6.03it/s]


Epoch 5 loss: 0.0967


100%|██████████| 25/25 [00:04<00:00,  5.96it/s]


Epoch 6 loss: 0.0924


100%|██████████| 25/25 [00:04<00:00,  5.97it/s]


Epoch 7 loss: 0.0747


100%|██████████| 25/25 [00:04<00:00,  5.99it/s]


Epoch 8 loss: 0.0763


100%|██████████| 25/25 [00:04<00:00,  6.10it/s]


Epoch 9 loss: 0.0647


100%|██████████| 25/25 [00:03<00:00,  6.30it/s]


Epoch 10 loss: 0.0672


100%|██████████| 25/25 [00:04<00:00,  5.91it/s]


Epoch 11 loss: 0.0617


100%|██████████| 25/25 [00:04<00:00,  6.00it/s]


Epoch 12 loss: 0.0668


100%|██████████| 25/25 [00:04<00:00,  5.83it/s]


Epoch 13 loss: 0.0577


100%|██████████| 25/25 [00:04<00:00,  5.98it/s]


Epoch 14 loss: 0.0583


100%|██████████| 25/25 [00:04<00:00,  6.24it/s]


Epoch 15 loss: 0.0611


100%|██████████| 25/25 [00:04<00:00,  6.19it/s]


Epoch 16 loss: 0.0659


100%|██████████| 25/25 [00:04<00:00,  6.09it/s]


Epoch 17 loss: 0.0574


100%|██████████| 25/25 [00:04<00:00,  5.92it/s]


Epoch 18 loss: 0.0588


 64%|██████▍   | 16/25 [00:02<00:01,  6.39it/s]

In [None]:
test_image = list(train_dataset)[5]['image']
lab2rgb = LABtoRGB()
rgb_image = lab2rgb(test_image)

In [None]:
output = model(test_image.unsqueeze(0).to(device))

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def display_rgb_image(rgb_tensor):
    # Convert the tensor to a numpy array
    rgb_array = rgb_tensor.numpy()
    
    # Transpose the array to match the required format for displaying using PIL
    rgb_array = np.transpose(rgb_array, (1, 2, 0))
    
    # Ensure the data type is uint8 and scale values to [0, 255]
    rgb_array = (rgb_array * 255).astype(np.uint8)
    
    # Convert the numpy array to a PIL Image
    img = Image.fromarray(rgb_array)
    
    # Display the image
    plt.imshow(img)
    plt.axis('off')
    plt.show()

print(test_image.shape)
print(rgb_image.shape)
# Assuming rgb_image_tensor is your tensor of shape [3, 256, 256]
display_rgb_image(rgb_image)


In [None]:
## AdaIN implementation
## TODO: see if the output size is the same as input size
class AdaIN(torch.nn.Module):
    def __init__(self):
        super(AdaIN, self).__init__()
        self.IN = torch.nn.InstanceNorm2d(512)
    
    def forward(self, x, y):
        size = x.size()
        
        x = self.IN(x)
        
        #mean_x, std_x = mean_and_std(x)
        mean_y, std_y = mean_and_std(y)
        #x = (x - mean_x.expand(size)) / std_x.expand(size)
        x = x * std_y.expand(size) + mean_y.expand(size)
        return x
""""
print(style.shape)
mean, std = mean_and_std(style)
print(mean.shape)
print(std.shape)
Ada = AdaIN()
t = Ada(vgg(content)[3], vgg(style)[3])
"""
