In [1]:
from dataset import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))  # 0 corresponds to the first GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Working on device: ", device)

TRAIN_SIZE = 10
TEST_SIZE = 10

True
NVIDIA GeForce RTX 3050 Ti Laptop GPU
Working on device:  cuda:0


In [3]:
train_dataset, test_dataset = prepare_dataset(TRAIN_SIZE,TEST_SIZE)
train_loader, test_loader = prepare_dataloader(train_dataset, test_dataset)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
  return torch.tensor(gray_img)


'All tests passed'

Dataset loaded successfully
Data loader prepared successfully


In [4]:
'''
Encoder is a pretrained VGG up to relu4_1 as in the original paper (see 6.1 paper)
'''
class VGG_Encoder(torch.nn.Module):
    def __init__(self):
        super(VGG_Encoder, self).__init__()
        pretrained = torchvision.models.vgg19(pretrained=True)
        
        f = torch.nn.Sequential(*list(pretrained.features.children())[:21]).eval()

        ## adding an extra conv layer because we have 1 channel images
        self.adjuster = torch.nn.Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        
        # Splitting the network so we can get output of different layers
        # TODO: ADD REFLECTION PADDING LAYERS
        self.relu1_1 = torch.nn.Sequential(*f[:2],)
        self.relu2_1 = torch.nn.Sequential(*f[2:5], *f[5:7])
        self.relu3_1 = torch.nn.Sequential(*f[7:10],*f[10:12])
        self.relu4_1 = torch.nn.Sequential(*f[12:14],
                                          *f[14:16],
                                          *f[16:19],
                                           *f[19:21])
        
        for param in self.relu1_1.parameters():
            param.requires_grad = False
        for param in self.relu2_1.parameters():
            param.requires_grad = False
        for param in self.relu3_1.parameters():
            param.requires_grad = False
        for param in self.relu4_1.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.adjuster(x)
        out_1 = self.relu1_1(x)
        out_2 = self.relu2_1(out_1)
        out_3 = self.relu3_1(out_2)
        result = self.relu4_1(out_3)
        return out_1, out_2, out_3, result

def mean_and_std(x):
    x = x.view(x.shape[0], x.shape[1], -1)
    mean = x.mean(dim=2) + 0.00005
    std = x.var(dim=2).sqrt()
    return mean.view(mean.shape[0], mean.shape[1], 1, 1), std.view(std.shape[0], std.shape[1], 1, 1)

In [5]:
encoder = VGG_Encoder()

# print(encoder.adjuster(torch.rand(4,1,256,256)).shape)

print(encoder)
out_1,out_2,out_3,out_4= encoder(torch.rand(4,1,256,256))



VGG_Encoder(
  (adjuster): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
  )
  (relu2_1): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
  )
  (relu3_1): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
  )
  (relu4_1): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (

In [6]:
print(out_1.shape,out_2.shape,out_3.shape,out_4.shape)

torch.Size([4, 64, 256, 256]) torch.Size([4, 128, 128, 128]) torch.Size([4, 256, 64, 64]) torch.Size([4, 512, 32, 32])


In [13]:
'''
https://medium.com/analytics-vidhya/unet-implementation-in-pytorch-idiot-developer-da40d955f201
'''

class conv_block(torch.nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(out_c)        
        self.conv2 = torch.nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(out_c)         
        self.relu = torch.nn.ReLU()     
        
    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)        

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x
    

class decoder_block(torch.nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = torch.nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(out_c+out_c, out_c)     
        
    def forward(self, inputs, skip):
        x = self.up(inputs)
        print("shape of inputs",inputs.shape,"shape of skip",skip.shape)  
        x = torch.cat([x, skip], axis=1)
        x = self.conv(x)
        return x
    

class Unet(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder = VGG_Encoder()


        ## ----------------- v1 ----------------

        ## Bottleneck
        self.bottleneck = conv_block(512, 1024)
        
        """ Decoder """
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)

        ## ----------------- v2 ----------------
        # """ Decoder """
        # self.d1 = decoder_block(512, 256)
        # self.d2 = decoder_block(256, 128)
        # self.d3 = decoder_block(128, 64)
        # self.d4 = decoder_block(64, 32)


        ## output should be 3 channels image
        self.out = torch.nn.Conv2d(64, 3, kernel_size=1, padding=0)

    def forward(self, x):
        ## Encoder
        out1, out2, out3, out4 = self.encoder(x)
        
        ## Decoder
        
        d1 = self.d1(b, out4)
        d2 = self.d2(d1, out3)
        d3 = self.d3(d2, out2)
        d4 = self.d4(d3, out1)
        
        out = self.out(d4)
        return out

In [14]:
model = Unet().to(device)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

NUM_EPOCHS = 2

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0

    for batch in tqdm(train_loader):
        X = batch['grayscale_image'].to(device)
        y = batch['image'].to(device)

        out = model(X)
        loss = criterion(out, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)
        
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1} loss: {epoch_loss:.4f}")

  0%|          | 0/3 [00:00<?, ?it/s]


shape of inputs torch.Size([4, 512, 32, 32]) shape of skip torch.Size([4, 512, 32, 32])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 64 but got size 32 for tensor number 1 in the list.

In [None]:
# ''' 
# decoder is just the second part of an Unet
# implement skip connections (feed concat to the upsample layer)
# '''
# class Decoder(torch.nn.Module):
#     def __init__(self):
#         super(Decoder, self).__init__()
#         ## TODO: adapt block one for single channel input
#         self.block1 = torch.nn.Sequential(
#             torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#             torch.nn.Conv2d(512, 256, (3, 3)),
#             torch.nn.ReLU())
#         self.block2 = torch.nn.Sequential(
#             torch.nn.Upsample(scale_factor=2, mode='nearest'),
#             torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#             torch.nn.Conv2d(256, 256, (3, 3)),
#             torch.nn.ReLU(),
#             torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#             torch.nn.Conv2d(256, 256, (3, 3)),
#             torch.nn.ReLU(),
#             torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#             torch.nn.Conv2d(256, 256, (3, 3)),
#             torch.nn.ReLU(),
#             torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#             torch.nn.Conv2d(256, 128, (3, 3)),
#             torch.nn.ReLU(),
#         )
#         self.block3 = torch.nn.Sequential(
#             torch.nn.Upsample(scale_factor=2, mode='nearest'),
#             torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#             torch.nn.Conv2d(128, 128, (3, 3)),
#             torch.nn.ReLU(),
#             torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#             torch.nn.Conv2d(128, 64, (3, 3)),
#             torch.nn.ReLU(),
#         )
#         self.block4 = torch.nn.Sequential(
#             torch.nn.Upsample(scale_factor=2, mode='nearest'),
#             torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#             torch.nn.Conv2d(64, 64, (3, 3)),
#             torch.nn.ReLU(),
#             torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#             torch.nn.Conv2d(64, 3, (3, 3)),
#         )


#         # self.decode = torch.nn.Sequential(
#         #     # torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#         #     # torch.nn.Conv2d(512, 256, (3, 3)),
#         #     # torch.nn.ReLU(),
#         #     # torch.nn.Upsample(scale_factor=2, mode='nearest'),
#         #     # torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#         #     # torch.nn.Conv2d(256, 256, (3, 3)),
#         #     # torch.nn.ReLU(),
#         #     # torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#         #     # torch.nn.Conv2d(256, 256, (3, 3)),
#         #     # torch.nn.ReLU(),
#         #     # torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#         #     # torch.nn.Conv2d(256, 256, (3, 3)),
#         #     # torch.nn.ReLU(),
#         #     # torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#         #     # torch.nn.Conv2d(256, 128, (3, 3)),
#         #     # torch.nn.ReLU(),
#         #     # torch.nn.Upsample(scale_factor=2, mode='nearest'),
#         #     # torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#         #     # torch.nn.Conv2d(128, 128, (3, 3)),
#         #     # torch.nn.ReLU(),
#         #     # torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#         #     # torch.nn.Conv2d(128, 64, (3, 3)),
#         #     # torch.nn.ReLU(),
#         #     # torch.nn.Upsample(scale_factor=2, mode='nearest'),
#         #     # torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#         #     # torch.nn.Conv2d(64, 64, (3, 3)),
#         #     # torch.nn.ReLU(),
#         #     # torch.nn.ReflectionPad2d((1, 1, 1, 1)),
#         #     # torch.nn.Conv2d(64, 3, (3, 3)),
#         # )
#     def forward(self, x, skips):
#         '''
#         skips is a list of tensors from the encoder
#         concatenation might not work because we have different depth channels
#         a 1x1 convolution is usually applied after that as a fusion mechanism
#         '''
#         out1 = self.block1(x)

#         ## skip connection
#         out2 = self.block2(torch.cat([out1,x], axis=1))
#         out3 = self.block3(torch.cat([out2,x], axis=1))
        
#         return result
# """
# decode = Decoder()
# img = decode(t)
# concat_img((img[:12]).detach().cpu())
# """

In [None]:
## try adain before skip connections or without to see if it makes a difference
encoder = VGG_Encoder()
decoder = Decoder()
## generate random tensor of at least 4 dimensions
random_tensor = torch.rand((12, 16, 26, 26))
style_image  = torch.rand((12, 16, 26, 26))
adain = AdaIN()
random_tensor = adain(random_tensor, style_image)
print(random_tensor.shape)

In [None]:
print(encoder.relu1_1)

In [None]:
x = encoder(torch.rand(1, 3, 256, 256))
print(x[3].shape)

through_adain = adain(x[3], x[3])

output = decoder(through_adain)
print(output.shape)

In [None]:
## AdaIN implementation
## TODO: see if the output size is the same as input size
class AdaIN(torch.nn.Module):
    def __init__(self):
        super(AdaIN, self).__init__()
        self.IN = torch.nn.InstanceNorm2d(512)
    
    def forward(self, x, y):
        size = x.size()
        
        x = self.IN(x)
        
        #mean_x, std_x = mean_and_std(x)
        mean_y, std_y = mean_and_std(y)
        #x = (x - mean_x.expand(size)) / std_x.expand(size)
        x = x * std_y.expand(size) + mean_y.expand(size)
        return x
""""
print(style.shape)
mean, std = mean_and_std(style)
print(mean.shape)
print(std.shape)
Ada = AdaIN()
t = Ada(vgg(content)[3], vgg(style)[3])
"""


In [None]:
a = torch.nn.Conv2d(5, 20, kernel_size=1)
len(a.weight)