In [None]:
import random
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sklearn
import torch,torchvision
from torch.nn import *
from sklearn.model_selection import *
from tqdm import tqdm
import cv2,os
from torch.optim import *
import pickle
import wandb

PROJECT_NAME = 'Trying-to-Turn-GrayScale-Images-to-Color-Images-I-dont-know-how-but-I-am-going-to-try'
# Fall back to CPU so the notebook still runs (slowly) on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Seed every RNG the notebook draws from: NumPy's global state, Python's
# `random`, and torch. The trailing `;` hides the Generator repr that
# torch.manual_seed returns, which would otherwise be printed as cell output.
np.random.seed(21)
random.seed(21)
torch.manual_seed(21);

In [None]:
def load_data(data_dir='./data/', size=(112,112)):
    """Load every readable image in ``data_dir`` as a ``[grayscale, color]`` pair.

    Each file is read twice with OpenCV: once with ``cv2.IMREAD_GRAYSCALE``
    (the model input) and once with the default color flag (the target, in
    OpenCV's BGR channel order). Both are resized to ``size``, which is
    ``(width, height)`` as ``cv2.resize`` expects.

    Files are visited in sorted order. ``os.listdir`` order is arbitrary, and
    the list is shuffled with a fixed seed afterwards, so sorting is what makes
    that shuffle reproducible across machines. Entries OpenCV cannot decode
    (``imread`` returns None: non-images, subdirectories) are skipped with a
    warning instead of crashing inside ``cv2.resize``.

    Args:
        data_dir: directory to read from (default ``'./data/'``).
        size: target ``(width, height)`` (default ``(112, 112)``).

    Returns:
        list of ``[X_one, y_one]`` pairs, where ``X_one`` is the resized
        grayscale array and ``y_one`` the resized color array.
    """
    import warnings

    data = []
    for file in tqdm(sorted(os.listdir(data_dir))):
        path = os.path.join(data_dir, file)
        X_one = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        y_one = cv2.imread(path)
        if X_one is None or y_one is None:
            warnings.warn(f'skipping unreadable file: {path}')
            continue
        data.append([cv2.resize(X_one, size), cv2.resize(y_one, size)])
    return data

In [None]:
# Load all [grayscale, color] image pairs from ./data/ via load_data(). This is the
# slow step (it decodes and resizes every file), so it stays in its own cell and
# the cells that only transform `data` can be re-run without reloading.
data = load_data()

In [None]:
# Shuffle the samples and build the flattened input/target tensors in one
# idempotent cell. The old cells mutated `data` with np.random.shuffle, so each
# re-run advanced the global RNG and produced a different order. They also
# appended into X/y from a separate cell, so a re-run duplicated samples.
# A dedicated RandomState(21) draws the same Fisher-Yates sequence as the
# global RNG right after np.random.seed(21). On a fresh kernel the order should
# therefore match the old one, assuming nothing else drew from the global NumPy
# RNG after the config cell. `data` itself is left untouched.
order = np.random.RandomState(21).permutation(len(data))
X = torch.from_numpy(np.array([data[i][0] for i in order])).view(-1,1*112*112).to(device).float()
y = torch.from_numpy(np.array([data[i][1] for i in order])).view(-1,3*112*112).to(device).float()

In [None]:
def predict(model, test_dir='./test_data/', out_dir='./preds/'):
    """Colorize each image in ``test_dir`` with ``model`` and save one figure per file.

    Each file is read as grayscale, resized to 112x112, and flattened to shape
    (1, 112*112) like the training batches. It is run through ``model`` with
    gradients disabled. The output is reshaped to (112, 112, 3), scaled from the
    0-255 target range to [0, 1], clipped (``imshow`` would otherwise clip with a
    warning), and saved with matplotlib to ``out_dir`` under the same file name.

    NOTE(review): channels stay in OpenCV's BGR order, so the saved figure shows
    red/blue swapped. The training cell re-reads it with ``cv2.imread``, whose
    BGR result swaps them back before logging to wandb. Change both together.

    Args:
        model: torch module producing 3*112*112 values per sample; evaluated on
            the global ``device``.
        test_dir: directory of input images (default ``'./test_data/'``).
        out_dir: directory for the saved figures; created if missing.

    The model's previous train/eval mode is restored afterwards, even on error.
    Files OpenCV cannot decode are skipped with a warning.
    """
    import warnings

    was_training = model.training
    model.eval()
    os.makedirs(out_dir, exist_ok=True)
    try:
        with torch.no_grad():
            for file in sorted(os.listdir(test_dir)):
                # Read from test_dir: the original listed ./test_data/ but read
                # the pixels from ./data/.
                path = os.path.join(test_dir, file)
                X_one = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                if X_one is None:
                    warnings.warn(f'skipping unreadable file: {path}')
                    continue
                X_one = cv2.resize(X_one, (112, 112))
                inputs = torch.from_numpy(X_one).to(device).float().view(1, 1*112*112)
                img = model(inputs).view(112, 112, 3).cpu().numpy() / 255.0
                fig = plt.figure(figsize=(12, 6))
                plt.imshow(np.clip(img, 0.0, 1.0))
                plt.savefig(os.path.join(out_dir, file))
                plt.close(fig)
    finally:
        model.train(was_training)

In [None]:
class Model_1(Module):
    """Convolutional colorizer: 112x112 grayscale image -> flattened 112x112x3 color image.

    ``forward`` reshapes its input to (N, 1, 112, 112), so it accepts the
    flattened (N, 1*112*112) training batches as well as image-shaped input.
    Features pass through three [5x5 conv (no padding) -> BatchNorm2d -> ReLU ->
    2x2 max-pool] stages, then three [Linear -> BatchNorm1d -> ReLU] stages. A
    final Linear produces output of shape (N, 3*112*112).
    """
    def __init__(self):
        super().__init__()
        self.activation = ReLU()
        self.max_pool2d = MaxPool2d((2,2),(2,2))
        self.conv1 = Conv2d(1,7,(5,5))
        self.conv1bn = BatchNorm2d(7)
        self.conv2 = Conv2d(7,14,(5,5))
        self.conv2bn = BatchNorm2d(14)
        self.conv3 = Conv2d(14,21,(5,5))
        self.conv3bn = BatchNorm2d(21)
        # Spatial size through the conv stack (each 5x5 conv trims 4 pixels,
        # each pool halves, rounding down): 112 -> 108 -> 54 -> 50 -> 25 -> 21 -> 10.
        # conv3 therefore yields 21 x 10 x 10 features per sample. The old
        # 21*3*3 made the flatten in forward() raise, or silently change the
        # batch dimension when N was a multiple of 9.
        self.linear1 = Linear(21*10*10,512)
        self.linear1bn = BatchNorm1d(512)
        self.linear2 = Linear(512,512)
        self.linear2bn = BatchNorm1d(512)
        self.linear3 = Linear(512,512)
        self.linear3bn = BatchNorm1d(512)
        self.output = Linear(512,3*112*112)
    
    def forward(self,X):
        X = X.view(-1,1,112,112)
        preds = self.max_pool2d(self.activation(self.conv1bn(self.conv1(X))))
        preds = self.max_pool2d(self.activation(self.conv2bn(self.conv2(preds))))
        preds = self.max_pool2d(self.activation(self.conv3bn(self.conv3(preds))))
        # flatten(1) keeps dim 0 as the batch dimension (N, 21*10*10).
        preds = preds.flatten(1)
        preds = self.activation(self.linear1bn(self.linear1(preds)))
        preds = self.activation(self.linear2bn(self.linear2(preds)))
        preds = self.activation(self.linear3bn(self.linear3(preds)))
        preds = self.output(preds)
        return preds

In [None]:
class Model_2(Module):
    """Fully-connected colorizer (no convolutions).

    Takes the flattened grayscale image, shape (N, 1*112*112), through four
    Linear -> BatchNorm1d -> ReLU stages of widths 256, 512, 1024 and 1024. A
    final Linear maps the result to (N, 3*112*112).
    """
    def __init__(self):
        super().__init__()
        self.activation = ReLU()
        # Stage k registers `linear{k}` then `linear{k}bn`, in the same order as
        # writing each layer out by hand. Submodule names, state_dict keys and
        # parameter-initialisation order are therefore unchanged.
        widths = [1*112*112, 256, 512, 1024, 1024]
        for stage, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f'linear{stage}', Linear(n_in, n_out))
            setattr(self, f'linear{stage}bn', BatchNorm1d(n_out))
        self.output = Linear(widths[-1], 3*112*112)

    def forward(self,X):
        hidden = X
        for stage in range(1, 5):
            dense = getattr(self, f'linear{stage}')
            norm = getattr(self, f'linear{stage}bn')
            hidden = self.activation(norm(dense(hidden)))
        return self.output(hidden)

In [None]:
# Training setup in a single cell. The optimizer must be built from the
# parameters of the model instance being trained. With model and optimizer in
# separate cells, re-running only the model cell left Adam stepping the old,
# discarded model's parameters.
model = Model_1().to(device)
criterion = MSELoss()
optimizer = Adam(model.parameters(),lr=0.001)
epochs = 100
batch_size = 32

In [None]:
wandb.init(project=PROJECT_NAME,name='Model_1-no-backward')
try:
    for _ in tqdm(range(epochs)):
        # Sum of this epoch's batch losses, kept on-device so there is no
        # per-batch host sync.
        epoch_loss = torch.zeros((), device=device)
        n_batches = 0
        for idx in range(0,len(X),batch_size):
            X_batch = X[idx:idx+batch_size].to(device).float().view(-1,1*112*112)
            y_batch = y[idx:idx+batch_size].to(device).float().view(-1,3*112*112)
            # BatchNorm1d raises in training mode on a single-sample batch. That
            # can only be the final batch, when len(X) % batch_size == 1.
            if len(X_batch) < 2:
                continue
            preds = model(X_batch)
            loss = criterion(preds,y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.detach()
            n_batches += 1
        predict(model)
        # Gather everything into one dict and call wandb.log once per epoch.
        # Every wandb.log call advances the step, so the old per-image calls
        # scattered loss and images across unrelated steps. 'Loss' is now the
        # mean of the epoch's batch losses instead of the last batch's loss.
        log = {}
        if n_batches:
            log['Loss'] = (epoch_loss / n_batches).item()
        for file in os.listdir('./preds/'):
            log[f'Pred Img/{file}'] = wandb.Image(cv2.imread(f'./preds/{file}')/255.0)
            log[f'Real Img/{file}'] = wandb.Image(cv2.imread(f'./test_data/{file}',cv2.IMREAD_GRAYSCALE)/255.0)
        wandb.log(log)
finally:
    # Close the run even if training raises or is interrupted.
    wandb.finish()