# Global settings

In [1]:
use_model_weights=False
dataset_path='./dataset/'
model_weights_path = './model_weights/model_weights.pth'
model = ''

In [2]:
from PIL import Image
import numpy as np
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from skimage.io import imsave
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch
import math, time
from tqdm import tqdm
from lion_pytorch import Lion
import torch.optim.lr_scheduler as lr_scheduler
from torch import device
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from gc import collect
from os.path import isfile
from os import remove
from psutil import virtual_memory
from torchvision import transforms
_ = torch.manual_seed(42)



# Loading data

In [3]:
# Get images
# image = Image.open(dataset_path+'woman.jpg').convert('RGB')
# image = np.array(image, dtype=int)

# Funkcija za konvertovanje slike iz sRGB u L\*a\*b\*

In [4]:
def convert_rgb_to_lab(image: np.ndarray) -> np.ndarray: 
    """
    Ova funkcija konvertuje image iz sRGB prostora u L*a*b* prostor.
    
    Parameters
    -----------
    image : np.NDArray, shape: Tuple[int,int,int[,int]]
        Slika(ili slike) treba da bude 3D/4D numpy array, gde je prva dimenzija redni broj slike(ako ih ima vise), druga dimenzija broj kanala, a treca i cetvrta sirina x visina. 
    
    Returns
    -------
        Returns a `np.ndarray` of shape (num_of_images, 3, height, width) if `image` is 4D, else if `image` is 3D it returns (3, height, width).
    """
    assert 3 in image.shape, f"Nije pronadjena nijedna dimenzija koja je =3"
    assert 3 <= image.ndim <= 4, f"Ocekivani broj dimenzija ulaznog parametra je izmedju 3 i 4(inclusive), a dobijeno je {image.ndim}"
    if image.ndim == 3:
        assert image.shape[0] == 3, f"Pogresna dimenzija na poziciji shape[0], ocekivano 3, dobijeno {image.shape[0]}"
        lab_image = rgb2lab(1.0/255*image, channel_axis=0)
        lab_image[1:] = lab_image[1:] / 128
        # X = lab_image[0]
        # Y = lab_image[1:]
        assert np.all( ( lab_image[0] >= 0) & ( lab_image[0] <= 100) ), f"U L* kanalu pronadjeno nedozvoljenih vrednosti"
        assert np.all( ( lab_image[1:] >= -1) & ( lab_image[1:] <= 1) ), f"U a*b* kanalima pronadjeno nedozvoljenih vrednosti"
    elif image.ndim == 4:
        assert image.shape[1] == 3, f"Pogresna dimenzija na poziciji shape[1], ocekivano 3, dobijeno {image.shape[1]}"
        lab_image = rgb2lab(1.0/255*image, channel_axis=1)
        lab_image[:,1:] = lab_image[:,1:] / 128
        assert np.all( (lab_image[:,0] >= 0 ) & (lab_image[:,0] <= 100 )), f"U L* kanalu pronadjeno nedozvoljenih vrednosti"
        assert np.all( (lab_image[:,1:] >= -1 ) & (lab_image[:,1:] <= 1 )), f"U a*b* kanalima pronadjeno nedozvoljenih vrednosti"
        # X = lab_image[:,0]
        # Y = lab_image[:,1:]

    return lab_image
    # Y /= 128
    # X = X.reshape(1 if image.ndim == 3 else image.shape[0], 1, image.shape[1 if image.ndim == 3 else 2], image.shape[2 if image.ndim == 3 else 3])
    # assert np.all( (X >= 0) & (X <= 100) , axis=(1,2) if image.ndim == 3 else (2,3)).any(), f"U L* kanalu pronadjeno nedozvoljenih vrednosti"
    # Y = Y.reshape(1 if image.ndim == 3 else image.shape[0], 2, image.shape[1 if image.ndim == 3 else 2], image.shape[2 if image.ndim == 3 else 3])
    # assert np.all( (Y >= -1) & (Y <= 1) ), f"U L* kanalu pronadjeno nedozvoljenih vrednosti"
    # return X,Y

In [5]:
# NOTE: testiramo ispravnost
# img = convert_rgb_to_lab(image.reshape(1,3,400,400))
# X,Y = img[:,0:1,:,:], img[:,1:,:,:]

# Funkcija za reverse convert

In [6]:
def convert_lab_to_rgb(image: np.ndarray, denormalize=False) -> np.ndarray:
    """
    Ova funkcija konvertuje iz L*a*b* prostora u sRGB prostor.

    Parameters
    ----------
    image: np.NDArray, shape: Tuple[int,int,int[,int]]
        Slika(ili slike) treba da bude 3D/4D numpy array, gde je prva dimenzija redni broj slike(ako ih ima vise), druga dimenzija oznacava kanal(L*, a* ili b*), a treca i cetvrta sirina x visina. L* kanal mora da sadrzi vrednosti od 0 do 100, dok a* i b* moraju imati vrednosti izmedju -128 i 127.
    denormalize : boolean=False
        Da li denormalizovati podatke. Ako je denormalize=`True`, onda se koristi sRGB opseg [0,255], u protivnom se koristi [0,1]

    Returns
    -------
    Vraca nam sliku(ili slike) u `numpy.ndarray` formatu.   
    """
    assert 3 in image.shape, f"Nije pronadjena nijedna dimenzija koja je =3"
    assert 3 <= image.ndim <= 4, f"Ocekivani broj dimenzija ulaznog parametra je izmedju 3 i 4(inclusive), a dobijeno je {image.ndim}"
    if image.ndim == 3:
        assert image.shape[0] == 3, f"Pogresna dimenzija na poziciji shape[0], ocekivano 3, dobijeno {image.shape[0]}"
        rgb_image = lab2rgb(image, channel_axis=0)
    elif image.ndim == 4:
        assert image.shape[1] == 3, f"Pogresna dimenzija na poziciji shape[1], ocekivano 3, dobijeno {image.shape[1]}"
        rgb_image = lab2rgb(image, channel_axis=1)
    rgb_image = rgb_image.reshape(1 if image.ndim == 3 else image.shape[0], 3, image.shape[1 if image.ndim == 3 else 2], image.shape[2 if image.ndim == 3 else 3])
    assert np.all( (rgb_image >= 0) & (rgb_image <=1.0)), f"Ocekivani opseg RGB vrednosti 0-1 je prekrsen"
    return rgb_image if not denormalize else (rgb_image * 255).astype(np.uint8)

In [7]:
# NOTE: testiramo ispravnost
# imgs_back_2rgb = convert_lab_to_rgb( np.concatenate( (X,Y), axis=1))

# Ucitavanje i pripremu dataseta

## Funkcija za ucitavanje i kreiranje memory-mapped dataseta

In [8]:
def load_data(mmap_mode = None, percentage:int = 0.7, shape=(25000,3,224,224)):
    """
    Treba da ucita podatke sa diska kao memory map. Memory-mapped podaci se ne ucitavaju svi u memoriju, vec se ucitavaju sa diska direktno po potrebi. 

    Parameters
    ----------
    mmap_mode : str | None
        U kom rezimu treba da ucitamo finalni dataset. Za vise videti [link](https://numpy.org/doc/stable/reference/generated/numpy.memmap.html#numpy.memmap). Ako je `None`, ucitace ceo dataset u memoriju.
    percentage : int = 0.7
        Koliko procenata dostupne sistemske memorije zelimo iskoristiti za konverziju.
    
    Returns
    -------
        Ucitani L*a*b* dataset `np.ndarray`, ako je mmap_mode=`None`, u protivnom vraca `memmap` i cita se sa diska.
    """
    should_delete_dataset=False
    assert 0.1 < percentage < 0.9, f"Ocekivan opseg procenata [10%,90%], dobijeno {percentage}"
    if isfile(dataset_path + 'lab_dataset.npy'):
        return np.memmap(dataset_path + 'lab_dataset.npy',mode=mmap_mode, shape=shape,dtype="float16")
    elif isfile(dataset_path + 'joined_dataset.npy'):
        joined_dataset = np.load(dataset_path + 'joined_dataset.npy', mmap_mode='r') # ucitavamo zdruzeni dataset
        dataset_shape = joined_dataset.shape
        dataset = np.memmap(dataset_path + 'lab_dataset.npy', mode="w+", shape=(dataset_shape[0],3,224,224), dtype="float16") # kreiramo memory-mapped fajl za finalni dataset (7 GB).
        available_system_memory_in_GBs = virtual_memory().available/1024**3
        how_much_memory_to_reserve_for_conversion_in_GBs = available_system_memory_in_GBs * percentage
        converted_image_size_in_GBs = dataset_shape[0] * 8 * 3 * 224 * 224 / 1024**3  # 28 GB u sustini ako sve odjednom konvertujem.
        print(f"Dostupna memorija za konverziju slika {how_much_memory_to_reserve_for_conversion_in_GBs}")
        print(f"Memorija potrebna za konverziju slika {converted_image_size_in_GBs}")

        try:
            if how_much_memory_to_reserve_for_conversion_in_GBs - converted_image_size_in_GBs >= 1: # Ostavljamo 1 GB overhead-a
                image = convert_rgb_to_lab(joined_dataset)
                dataset[:] = image
            else: # U protivnom koristimo batched obradu.
                batch_size = int( (how_much_memory_to_reserve_for_conversion_in_GBs-1)*dataset_shape[0] / (32*converted_image_size_in_GBs) ) # -1 zbog memory overheada. 32/64 jer rgb2lab koristi float64...
                print(f"Batch size:\t{batch_size}")
                assert batch_size > 1, f"Nemate dovoljno memorije za ovakvu operaciju, ocekivano je da batch_size bude veci od 1, ali je {batch_size}"
                for i in range(0, dataset_shape[0], batch_size+1):
                    dataset[i:batch_size] = convert_rgb_to_lab( joined_dataset[i:batch_size] )
                dataset[i:] = convert_rgb_to_lab( joined_dataset[i:] )
        except MemoryError as e:
            print(f"Doslo je do greske:\n{e}")
            should_delete_dataset = True
        except Exception as e:
            should_delete_dataset = True
            print(e)
        finally:
            dataset.flush()
            print("Flushed!")
            del dataset, joined_dataset
            collect()
            if should_delete_dataset:
                remove(dataset_path + 'lab_dataset.npy')
                return

        print("Kreirani finalni dataset")
        return load_data(mmap_mode, percentage, dataset_shape)
    else:
        X = np.load(dataset_path + 'l/gray_scale.npy',mmap_mode='r').reshape(25000,1,224,224)
        y_file_to_load = [ f'ab/ab/ab{i}.npy' for i in range(1,4)]
        Y = np.concatenate( [ np.load(dataset_path + file) for file in y_file_to_load ], axis=0 ).reshape(25000,2,224,224)
        np.save(dataset_path + 'joined_dataset.npy', np.concatenate( (X,Y), axis=1))
        del X,Y,y_file_to_load
        collect()
        print("Kreirani zdruzeni dataset")
        return load_data(mmap_mode, percentage)

## Definisanje transformacija slika

In [9]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=20, shear=0.2, scale=(0.8, 1.2)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
])

## Klasa naseg custom dataseta

In [10]:
# definisanje custom dataseta
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, data, transform=None):
        """
        Inicijalizuje dataset

        Parameters
        ----------
        data : array-like
            Nas dataset, oblika (broj_slika, 3, height, width).
        transform(optional) : torchvision.Compose
            Transformacije koje koje primenjujemo nad nasim slikama.
        """
        self.data = data
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,index):
        image = self.data[index]
        # image = Image.fromarray(image)
        if self.transform:
            print("be4 transform")
            image = self.transform(image)
            print("after transform")
        print(image)
        print(image.shape)
        X = image[0]
        Y = image[1:]
        return X,Y

## Ucitavanje podataka, podela na trening i test

In [11]:
dataset_memory_map = load_data('r')

In [12]:
dataset_memory_map.shape

(25000, 3, 224, 224)

In [13]:
test_size = 0.2
validation_size = 0.15

test_length = int(test_size * dataset_memory_map.shape[0])
remaining_length = dataset_memory_map.shape[0] - test_length
validation_length = int(remaining_length * validation_size)
training_length = remaining_length - validation_length

training_length, validation_length, test_length

indices = [i for i in range(dataset_memory_map.shape[0])]

training_indices = np.random.choice(range(0, dataset_memory_map.shape[0]), size=training_length, replace=False)
validation_indices = np.random.choice(list(set(range(0, dataset_memory_map.shape[0])) - set(training_indices)), size=validation_length, replace=False)
test_indices = np.random.choice(list(set(range(0, dataset_memory_map.shape[0])) - set(training_indices) - set(validation_indices)), size=test_length, replace=False)

training = ImageDataset(dataset_memory_map[training_indices], transform=transform)
validation = ImageDataset(dataset_memory_map[validation_indices], transform=transform)
# test = ImageDataset(dataset_memory_map[test_indices], transform=transform)

# Model class

In [14]:
class ColorizerModel(nn.Module):
    def __init__(self):
        super(ColorizerModel, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=2)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2)
        self.conv7 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.conv9 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv10 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv11 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.conv12 = nn.Conv2d(32, 2, kernel_size=3, padding=1)
        self.tanh = nn.Tanh()
        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = nn.ReLU()(self.conv2(x))
        x = nn.ReLU()(self.conv3(x))
        x = nn.ReLU()(self.conv4(x))
        x = nn.ReLU()(self.conv5(x))
        x = nn.ReLU()(self.conv6(x))
        x = nn.ReLU()(self.conv7(x))
        x = nn.ReLU()(self.conv8(x))
        x = nn.ReLU()(self.conv9(x))
        x = self.upsample1(x)
        x = nn.ReLU()(self.conv10(x))
        x = self.upsample2(x)
        x = nn.ReLU()(self.conv11(x))
        x = self.tanh(self.conv12(x))
        x = self.upsample3(x)
        return x

# Early stopping

In [15]:
class EarlyStopping(object):
    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False, also_use_timer=False, seconds_to_terminate:int=60*60):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self.also_use_timer=also_use_timer
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            self.is_better = lambda a, b: True
            self.step = lambda a: False

        if also_use_timer:
            self.start_time=time.perf_counter()
            self.end_time = 0
            self.time_compare = lambda start,end: end-start >= seconds_to_terminate # NOTE Terminate after an hour
        else:
            self.start_time=None
            self.end_time=None
            self.time_compare = lambda start,end: False

    def step(self, metrics):
        if self.best is None:
            self.best = metrics
            return False

        if torch.is_tensor(metrics):
            if torch.isnan(metrics):
                return True
        elif type(metrics) == float and math.isnan(metrics):
              return True

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            print('terminating because of early stopping!')
            return True

        return False

    def time_ran_out(self):
        if self.time_compare(self.start_time, self.end_time):
            print("Terminating because of training time limit.")
            return True
        return False

    def _init_is_better(self, mode, min_delta, percentage):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (
                            best * min_delta / 100)

# Functions for training and evaluating a model

In [16]:
def evaluate(model: nn.Module, validation: DataLoader, device, metric, is_called_from_training=False):
    model.eval()
    img_real = []
    img_pred = []
    for step, batch in enumerate(validation):
        with torch.no_grad():
            inputs, targets = batch
            if(not is_called_from_training):
                inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            img_real.append(targets.cpu().numpy().reshape(validation.batch_size, 3, 224, 224))#NOTE: mozda puca
            img_pred.append(outputs.cpu().numpy().reshape(validation.batch_size, 3, 224, 224))#NOTE: mozda puca

    model.train()
    metric.update(np.concatenate(img_real, axis=0), real=True)
    metric.update(np.concatenate(img_real, axis=0), real=False)
    return metric.compute()

def fit(model: nn.Module, optimizer: optim.Optimizer, training:DataLoader, validation: DataLoader, scheduler: lr_scheduler.LRScheduler, metric_for_early_stopping,  epochs:int=50, loss_fn=nn.MSELoss(), gradient_accumulation_steps:int=8,enable_early_stopping:bool=True,patience:int=7,early_stopping_mode:str='min',delta_for_early_stopping:float=0,best:float=None,also_use_timer_for_early_stopping:bool=False, seconds_for_early_stopping:int=60*60, device:str='cpu'):
    """
    Trenira/fituje model.

    Parameters
    ----------
    model : nn.Module
        Ovo je objekat instanciranog modela kojeg treniramo
    optimizer : optim.Optimizer
        Optimizator parametara `model` koje koristimo
    training : DataLoader
        DataLoader za trening
    validation : DataLoader
        DataLoader za validaciju
    epochs : int, optional
        Broj epoha prilikom treninga
    loss_fn : optional
        Funkcija za generisanja loss-a tokom treninga `model`-a.
    scheduler : LRSCheduler 
        Scheduler za `learning_rate` 
    gradient_accumulation_steps : int, optional
        Koliko step-ova akumuliramo gradijente pre nego sto uradimo apdejt vejtova. Ako ne zelimo akumuliranje gradijenata, setovati ovaj parametar na 1.
    early_stopping_mode : str, optional
        Rezim rada early stopping mehanizma(moze biti `min` ili `max`)
    patience : int, optional
        Koliko koraka u EarlyStoppingu tolerisemo pre nego sto prekinemo trening
    delta_for_early_stopping : float, optional
        Tolerancija odstupanja performansi za early stopping
    metric_for_early_stopping : str, optional
        Koju metriku cemo koristiti za early stopping. 
    best : float, optional
        Najbolji rezultat koji je model postigao. Podrazumevano nema, ako instanciramo model od 0.
    also_use_timer_for_early_stopping : bool, optional
        Da li se koristi i tajmer za early stopping(ako npr. zelimo da trening traje odredjeno vreme)
    device : {'cpu', 'cuda'}
        Na kojem uredjaju zelimo da se vrsi trening.
    """
    model = model.to(device)

    trainingSteps = epochs * len(training)
    
    if enable_early_stopping:
        earlyStopping = EarlyStopping(patience=min(epochs, patience), mode=early_stopping_mode, min_delta=delta_for_early_stopping,also_use_timer=also_use_timer_for_early_stopping, seconds_to_terminate=seconds_for_early_stopping)
        best = best

    model.train()
    completed_steps = 0
    for epoch in range(epochs):
        for step, batch in tqdm( enumerate(training, start=1), total=trainingSteps):
            # outputs = model(**batch)
            # loss = outputs.loss
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # print(inputs)
            loss = loss_fn(outputs, targets)
            loss = loss / gradient_accumulation_steps
            loss.backward()
            if step % gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                completed_steps += 1
        
        # evaluation = evaluate(model, validation, metric_for_early_stopping, is_called_from_training=True)
        # print(f"Tokom epohe {epoch+1} loss je bio {loss} sa akumuliranjem, tj. {loss*gradient_accumulation_steps} bez akumuliranja gradijenta, learning rate je {scheduler.get_last_lr()}")
        if enable_early_stopping:
            earlyStopping.end_time = time.perf_counter()
            # if best is None or evaluation > best:
            #     best = evaluation[metric_for_early_stopping]
            #     torch.save(model.state_dict(), model_weights_path) # NOTE: Mozda puca

            # if earlyStopping.step(evaluation[metric_for_early_stopping]) or earlyStopping.time_ran_out():
            #     return

# Training

## Initialization

In [17]:
# asd = np.array([ img.reshape(3,400,400) for i in range(8) ])

In [18]:
# asd.shape

In [19]:
# X, Y = asd[:,0,:,:].reshape(1,1,400,400), asd[:,1:,:,:].reshape(1,2,400,400)

In [20]:
number_of_epochs = 3
batch_size = 64

model = ColorizerModel()
if use_model_weights:
    model.load_state_dict(torch.load(model_weights_path))
optimizer = Lion(model.parameters(), lr=3.67*0.001) 
# optimizer = torch.optim.RMSprop(model.parameters(), lr=5e-5)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
training = DataLoader(training, batch_size=batch_size, shuffle=True)
# learning_rate_scheduler = lr_scheduler.LinearLR(optimizer=optimizer, start_factor=0.9,end_factor=1/5,total_iters=number_of_epochs * len(training))
learning_rate_scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=number_of_epochs * len(training), eta_min=0.1)
metric = FrechetInceptionDistance(feature=64,normalize=True)
validation = DataLoader(validation, batch_size=batch_size, shuffle=True)

## Trening

In [21]:
fit(model=model,optimizer=optimizer,training=training, validation=validation,scheduler=learning_rate_scheduler,epochs=number_of_epochs,device=device, gradient_accumulation_steps=8, metric_for_early_stopping=metric, enable_early_stopping=False)

  0%|          | 0/798 [00:00<?, ?it/s]

[[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]]
(3, 224, 224)
be4 transform


  0%|          | 0/798 [00:00<?, ?it/s]


RuntimeError: "baddbmm_with_gemm" not implemented for 'Half'

In [None]:
for sth in enumerate(training):
    print(sth)

# Test

In [None]:
# TODO...

# Inference

In [None]:
X.shape

In [None]:
output.shape

In [None]:
asd = convert_lab_to_rgb( np.concatenate((X,output),axis=1))

In [None]:
asd[0,]

In [None]:
model.eval()
with torch.no_grad():
    output = model(torch.tensor(X).float().to(device)).cpu().numpy()
    output *= 128
    rgb_img = convert_lab_to_rgb(np.concatenate((X,output), axis=1),denormalize=True)
    imsave("img_result.png", rgb_img[0].reshape(400,400,3))
    imsave("img_result_gray_version.png", (255*rgb2gray(rgb_img[0].reshape(400,400,3))).astype(np.uint8))
model.train()