In [None]:
# --- Standard library ---
import glob
from importlib import reload

# --- Third-party ---
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch

# --- Project modules ---
# Import the modules themselves first so they can be hot-reloaded below; this
# lets edits to the .py files take effect without restarting the kernel.
import loss
import data_generator
import model
import train
import visualize

# Reload order: loss, visualize, model, data_generator, train.
for _module in (loss, visualize, model, data_generator, train):
    reload(_module)
del _module

# Bind the names used by the rest of the notebook *after* reloading, so they
# refer to the freshly reloaded definitions.
from data_generator import make_dataloaders
from model import MainModel
from train import train_model, load_model, build_res_unet, pretrain_generator
from background_detection import load_processed_images

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Experiment configuration.
COLOR_SPACE = 'Lab'  # Lab or HSL or YCbCr
path = "./data/part1"
l1loss = "SmoothL1Loss"  # SmoothL1Loss or L1Loss
model_path = f"./models/model_pretrained_noBG_part1_{COLOR_SPACE}_{l1loss}.pt"

# Sampling / split sizes: N_SAMPLES distinct images are drawn; after shuffling,
# the first N_TRAIN go to training and the remaining N_SAMPLES - N_TRAIN to
# validation.
N_SAMPLES = 7500
N_TRAIN = 6000
SEED = 123

paths = load_processed_images('./background_scores/filtered_part1')
assert len(paths) >= N_SAMPLES, (
    f"need at least {N_SAMPLES} images, found {len(paths)}"
)

np.random.seed(SEED)
paths_subset = np.random.choice(paths, N_SAMPLES, replace=False)  # N_SAMPLES distinct images
rand_idxs = np.random.permutation(N_SAMPLES)
train_idxs = rand_idxs[:N_TRAIN]  # first N_TRAIN shuffled indices -> training set
# Validation takes the *complement* of the training slice. (Previously this was
# rand_idxs[1500:], which overlapped 4500 indices with the training set.)
val_idxs = rand_idxs[N_TRAIN:]
train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]

# Guard against the splits leaking into each other.
assert not set(train_idxs.tolist()) & set(val_idxs.tolist()), "train/val overlap"
print(len(train_paths), len(val_paths))

In [None]:
# Spot-check a 4x4 grid of training images; zip stops once all 16 panels are used.
fig, grid = plt.subplots(4, 4, figsize=(10, 10))
for img_path, panel in zip(train_paths, grid.flatten()):
    panel.imshow(Image.open(img_path))
    panel.axis("off")

In [None]:
# One DataLoader per split, both built for the configured colour space.
train_dl, val_dl = (
    make_dataloaders(paths=split_paths, split=split_name, color_space=COLOR_SPACE)
    for split_name, split_paths in (('train', train_paths), ('val', val_paths))
)

# Pull a single training batch to sanity-check tensor shapes.
sample_batch = next(iter(train_dl))
known_channels = sample_batch['known_channel']
unknown_channels_ = sample_batch['unknown_channels']
print(known_channels.shape, unknown_channels_.shape)
print(len(train_dl), len(val_dl))  # batches per split

In [None]:
# Pretrain the generator on its own with a plain L1 objective
# (1 known channel in -> 2 unknown channels out) before MainModel training.
pretrain_epochs = 20

net_G = build_res_unet(n_input=1, n_output=2, size=256)
opt = torch.optim.Adam(net_G.parameters(), lr=1e-4)
criterion = torch.nn.L1Loss()
pretrain_generator(net_G, train_dl, opt, criterion, pretrain_epochs, device)

# Weights file written by the next cell.
pretrained_model = f"res18-unet_noBG_{COLOR_SPACE}.pt"

In [None]:
# Persist only the generator's parameters (its state_dict), not the module object.
generator_state = net_G.state_dict()
torch.save(generator_state, pretrained_model)

In [None]:
# Rebuild a fresh generator and load the pretrained weights saved to `pretrained_model`.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
# The file holds a plain state_dict, so restrict unpickling to tensors and
# primitive containers (weights_only=True) rather than arbitrary Python objects.
pretrained_state = torch.load(pretrained_model, map_location=device, weights_only=True)
net_G.load_state_dict(pretrained_state)

In [None]:
# Full MainModel training on top of the pretrained generator.
# Set RESUME_TRAINING = True to continue from the checkpoint at `model_path`:
# load_model is called on the fresh model and the loss meters it returns are
# handed back to train_model (presumably restoring weights + loss history —
# confirm in train.py).
RESUME_TRAINING = False

loaded_model = MainModel(net_G=net_G, L1LossType=l1loss)
resume_kwargs = {}
if RESUME_TRAINING:
    _, loss_meter_dict = load_model(model_path, loaded_model)
    resume_kwargs["loss_meter_dict"] = loss_meter_dict

# NOTE(review): 150 and 70 are passed positionally to train_model; presumably
# an epoch count and a logging/display interval — confirm against train.py.
train_model(loaded_model, train_dl, val_dl, COLOR_SPACE, 150, 70,
            save_path=model_path, **resume_kwargs)


In [None]:
# Show model predictions on the first few validation batches.
# Call the function through its module (`visualize.visualize`). The form
# `from visualize import visualize` rebinds the name `visualize` from the module
# to the function, and re-running the import cell then fails in
# `reload(visualize)` with a TypeError.
N_VIS_BATCHES = 5

# Rebuild the model exactly as for training (same L1 loss type), then restore
# the trained checkpoint from `model_path`.
loaded_model = MainModel(net_G=net_G, L1LossType=l1loss)
_, loss_meter_dict = load_model(model_path, loaded_model)
for batch_idx, batch in enumerate(val_dl):
    visualize.visualize(loaded_model, batch, COLOR_SPACE, save=False)
    if batch_idx + 1 >= N_VIS_BATCHES:
        break