In [None]:
import copy

import matplotlib.pyplot as plt
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from tqdm.notebook import tqdm
import seaborn as sns

# Seaborn's default theme for every figure, then a wide 16x9 default figure size
# (set afterwards so the theme cannot override it).
sns.set_theme()
plt.rcParams.update({"figure.figsize": (16, 9)})

from nircoloring.config import SERENGETI_NIR_INCANDESCENT_DATASET_OUT
from nircoloring.dataset.caltech import SerengetiMetaDataSource

In [None]:
# CycleGAN's test script writes `<name>_real.<ext>` / `<name>_fake.<ext>` pairs into each
# results directory. Delete the `_real` copies and rename every `_fake` image to
# `<name>.jpg` so it carries the same base name as its source image. Only the file name
# changes; the image data is not re-encoded.
networks_to_prepare = [("cycle_gan_serengeti_inc_0_000015_0_000045", "200")]

for name, epoch in networks_to_prepare:
    for phase in ("train", "val", "test"):
        directory = f"/tmp/borstelmanna0/results/{phase}/{name}/test_{epoch}/images"
        for file in tqdm(os.listdir(directory)):
            stem = os.path.splitext(file)[0]
            filepath = os.path.join(directory, file)
            if stem.endswith("_real"):
                os.remove(filepath)
            elif stem.endswith("_fake"):
                os.rename(filepath, os.path.join(directory, stem[: -len("_fake")] + ".jpg"))

In [None]:
# CUT writes its generated images to `<name>/<phase>_<epoch>/images/fake_B`; give each
# of them a `.jpg` extension so the names line up with the original dataset file names.
networks_to_prepare = [("cut_serengeti_inc_0_000002", "400")]

for name, epoch in networks_to_prepare:
    for phase in ("train", "val", "test"):
        fake_dir = f"/tmp/borstelmanna0/cut-results/{name}/{phase}_{epoch}/images/fake_B"
        for file in tqdm(os.listdir(fake_dir)):
            stem, _ = os.path.splitext(file)
            os.rename(os.path.join(fake_dir, file), os.path.join(fake_dir, f"{stem}.jpg"))

In [None]:
from PIL import Image
from torch.utils.data.dataset import T_co
from torch.utils.data import Dataset
from nircoloring.dataset.caltech import AbstractMetaDataSource


def parse_label_map(meta_data_source: AbstractMetaDataSource, entries: pd.Series):
    """Look up the annotated categories for the images named in ``entries``.

    Args:
        meta_data_source: provides ``load_images()`` (read for its ``id`` and
            ``file_name`` columns) and ``load_annotations()`` (read for its
            ``image_id`` and ``category_id`` columns).
        entries: file names to look up. They are compared case-insensitively with
            the last ``/``-separated component of each image's ``file_name``.

    Returns:
        DataFrame indexed by the lower-cased ``base_file_name``. Its single column
        ``category_id`` holds the distinct category ids annotated for that image, or
        NaN when the image has no annotation (left merge).

    Note:
        Adds a ``base_file_name`` column to the frame returned by ``load_images()``.
    """
    images = meta_data_source.load_images()
    annotations = meta_data_source.load_annotations()

    # Drop the directory part and lower-case, so `.JPG` and `.jpg` names still match.
    images["base_file_name"] = images["file_name"].str.lower().str.split("/").str[-1]
    wanted = entries.str.lower()
    matched = images.loc[images["base_file_name"].isin(wanted)]

    categories_per_image = annotations.groupby("image_id")["category_id"].unique()
    labelled = matched.merge(categories_per_image, how="left", left_on="id", right_on="image_id")
    return labelled.set_index("base_file_name")[["category_id"]]


class SerengetiDataset(Dataset):
    """Classification dataset over every file in one Serengeti image directory.

    Items are dicts ``{'image': transform(<RGB PIL image>), 'labels': <int category id>}``.
    Labels come from ``parse_label_map`` and are matched to files case-insensitively
    by base file name.
    """

    def __init__(self, path, meta_data_source, transform, size) -> None:
        """Index the files in ``path`` and resolve their labels.

        Args:
            path: directory whose entries are all treated as images.
            meta_data_source: metadata source handed to ``parse_label_map``.
            transform: callable applied to each RGB ``PIL.Image``.
            size: stored as ``self.size``; this class does not read it.
        """
        self.path = path
        # os.listdir returns entries in an arbitrary, filesystem-dependent order. Sort
        # them (case-insensitively, like the label lookup) so that item order is
        # reproducible and two directories holding the same file names enumerate them
        # identically, which zipped evaluation loaders rely on.
        self.entries: pd.Series = pd.Series(
            sorted(os.listdir(path), key=lambda name: (name.lower(), name))
        )
        self.label_map: pd.DataFrame = parse_label_map(meta_data_source, self.entries)
        self.transform = transform
        self.size = size

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, index) -> T_co:
        filename = self.entries[index]
        # The context manager releases the file handle once the pixels are converted.
        with Image.open(os.path.join(self.path, filename)) as raw_image:
            img = raw_image.convert('RGB')
        img = self.transform(img)

        labels_indices = self.label_map.loc[filename.lower(), "category_id"]
        # parse_label_map left-merges the annotations, so an image without any
        # annotation gets a scalar NaN instead of an array of category ids.
        if pd.api.types.is_scalar(labels_indices):
            raise KeyError(f"no category annotation for image {filename!r}")

        return {
            'image': img,
            # An image may carry several categories; only the first one is used.
            'labels': int(labels_indices[0])
        }

In [None]:
# Every image is resized to 256x256 and turned into a tensor; no normalisation is applied.
transformations = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])


def _serengeti_split(subdir):
    """Build a SerengetiDataset over ``<SERENGETI_NIR_INCANDESCENT_DATASET_OUT>/<subdir>``."""
    return SerengetiDataset(
        path=os.path.join(SERENGETI_NIR_INCANDESCENT_DATASET_OUT, subdir),
        transform=transformations,
        meta_data_source=SerengetiMetaDataSource(),
        size=61,
    )


# The `*A` splits feed the `nir_*` datasets, the `*B` splits the `inc_*` datasets.
nir_train_dataset = _serengeti_split("trainA")
nir_test_dataset = _serengeti_split("testA")
nir_val_dataset = _serengeti_split("valA")

inc_train_dataset = _serengeti_split("trainB")
inc_test_dataset = _serengeti_split("testB")
inc_val_dataset = _serengeti_split("valB")

In [None]:
# `inc_test_dataset` is already built, with identical arguments, in the dataset-setup
# cell. Rebuilding it here only re-parsed the Serengeti metadata, so this cell no
# longer re-creates it.

In [None]:
# CycleGAN translations of each phase, one results directory per phase.
_CYCLE_GAN_IMAGES = "/tmp/borstelmanna0/results/{phase}/cycle_gan_serengeti_inc_0_000015_0_000045/test_200/images"


def _cycle_gan_split(phase):
    """SerengetiDataset over the CycleGAN outputs of one phase (`train`/`test`/`val`)."""
    return SerengetiDataset(
        path=_CYCLE_GAN_IMAGES.format(phase=phase),
        transform=transformations,
        meta_data_source=SerengetiMetaDataSource(),
        size=61,
    )


cycle_gan_train_dataset = _cycle_gan_split("train")
cycle_gan_test_dataset = _cycle_gan_split("test")
cycle_gan_val_dataset = _cycle_gan_split("val")

In [None]:
# CUT translations of each phase (generated images live in `fake_B`).
_CUT_IMAGES = "/tmp/borstelmanna0/cut-results/cut_serengeti_inc_0_000002/{phase}_400/images/fake_B/"


def _cut_split(phase):
    """SerengetiDataset over the CUT outputs of one phase (`train`/`test`/`val`)."""
    return SerengetiDataset(
        path=_CUT_IMAGES.format(phase=phase),
        transform=transformations,
        meta_data_source=SerengetiMetaDataSource(),
        size=61,
    )


cut_train_dataset = _cut_split("train")
cut_test_dataset = _cut_split("test")
cut_val_dataset = _cut_split("val")

In [None]:
def accuracy(outputs, target: torch.Tensor):
    """Return the fraction of samples whose highest-scoring class equals ``target``.

    Args:
        outputs: per-class scores with the classes along dim 1
            (presumably ``(batch, num_classes)`` logits).
        target: class index per sample.
    """
    predicted = outputs.max(dim=1).indices
    n_correct = (predicted == target).sum().item()
    return n_correct / len(predicted)

In [None]:
def train(model: torch.nn.Module, dataloader: DataLoader, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module,
          device):
    """Run one training epoch and return ``(mean_loss, accuracy)`` over all samples.

    Each batch is a dict with an ``'image'`` tensor and a ``'labels'`` tensor of class
    indices. Loss and accuracy are weighted by batch size, so a smaller final batch
    counts for only as many samples as it holds.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_correct = 0
    n_samples = 0
    for batch in tqdm(dataloader, "Training"):
        images, targets = batch['image'].to(device), batch['labels'].to(device)
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        # Assumes `criterion` averages over the batch (CrossEntropyLoss' default
        # 'mean' reduction), so multiply back up to a per-batch sum.
        total_loss += loss.item() * batch_size
        n_correct += (outputs.argmax(dim=1) == targets).sum().item()
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / n_samples, n_correct / n_samples

In [None]:
def validate(model: torch.nn.Module, dataloader: DataLoader, criterion: torch.nn.Module, device):
    """Evaluate ``model`` on ``dataloader`` without gradients.

    Each batch is a dict with an ``'image'`` tensor and a ``'labels'`` tensor of class
    indices.

    Returns:
        ``(mean_loss, accuracy)`` over all samples, weighted by batch size so that a
        smaller final batch does not count like a full one.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_samples = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, "Validating"):
            images, targets = batch['image'].to(device), batch['labels'].to(device)

            outputs = model(images)
            loss = criterion(outputs, targets)

            batch_size = targets.size(0)
            # Assumes `criterion` averages over the batch (CrossEntropyLoss' default).
            total_loss += loss.item() * batch_size
            n_correct += (outputs.argmax(dim=1) == targets).sum().item()
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / n_samples, n_correct / n_samples

In [None]:
def plot(train_loss, val_loss, train_acc, val_acc):
    """Show loss (left) and accuracy (right) curves per epoch for training and validation.

    Args:
        train_loss, val_loss: per-epoch training / validation loss.
        train_acc, val_acc: per-epoch training / validation accuracy.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2)

    loss_ax.plot(train_loss, label='train loss')
    loss_ax.plot(val_loss, label='validation loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    acc_ax.plot(train_acc, label='train accuracy')
    # These values come from the validation loader; they were mislabelled "test".
    acc_ax.plot(val_acc, label='validation accuracy')
    acc_ax.set_xlabel('Epochs')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    plt.show()


def fit(model, epochs, optimizer, criterion, train_loader, val_loader, plot_freq=5, device=None):
    """Train ``model`` and restore the weights of its best validation epoch.

    Args:
        model: network to train (updated in place and also returned).
        epochs: number of passes over ``train_loader``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss applied to ``(outputs, labels)``.
        train_loader: training batches (dicts with ``'image'`` and ``'labels'``).
        val_loader: validation batches used for model selection.
        plot_freq: draw the learning curves every ``plot_freq`` epochs; 0/None
            disables the intermediate plots. The final plot is always drawn.
        device: device batches are moved to. Defaults to the device of the model's
            first parameter instead of a notebook-global ``device``.

    Returns:
        ``model`` with the weights of the epoch that had the highest validation
        accuracy (above 0), or its final weights if no epoch exceeded 0.
    """
    if device is None:
        device = next(model.parameters()).device

    train_loss, train_acc = [], []
    val_loss, val_acc = [], []

    best_acc = 0.0
    # Snapshot of the best weights; stays None until an epoch beats `best_acc`.
    best_state = None

    for epoch in tqdm(range(epochs), "Processing epochs"):
        train_epoch_loss, train_acc_epoch = train(model, train_loader, optimizer, criterion, device)
        val_epoch_loss, val_acc_epoch = validate(model, val_loader, criterion, device)

        train_loss.append(train_epoch_loss)
        train_acc.append(train_acc_epoch)
        val_loss.append(val_epoch_loss)
        val_acc.append(val_acc_epoch)

        print(f"Train Loss: {train_epoch_loss:.4f}, Acc: {train_acc_epoch}\n")
        print(f"Val Loss: {val_epoch_loss:.4f}, Acc: {val_acc_epoch}\n")

        if val_acc_epoch > best_acc:
            best_acc = val_acc_epoch
            # state_dict() holds references to the live tensors, which later
            # optimizer steps overwrite, so keep a deep copy.
            best_state = copy.deepcopy(model.state_dict())

        # Guard against plot_freq == 0, which used to raise ZeroDivisionError.
        if plot_freq and epoch % plot_freq == 0:
            plot(train_loss, val_loss, train_acc, val_acc)

    plot(train_loss, val_loss, train_acc, val_acc)
    if best_state is not None:
        model.load_state_dict(best_state)
    return model

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# ResNet-50 classifier for the NIR images: conv1, layer1 and layer2 are frozen and the
# final fully-connected layer becomes dropout + a 61-way linear head.
# NOTE(review): `models.resnet50()` is built without pretrained weights, so the frozen
# layers keep their random initialisation for the whole run. Confirm whether ImageNet
# weights (e.g. `weights=models.ResNet50_Weights.DEFAULT`) were intended.
nir_net = models.resnet50()
nir_net.requires_grad_(True)
for frozen in (nir_net.conv1, nir_net.layer1, nir_net.layer2):
    frozen.requires_grad_(False)
nir_net.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features=nir_net.fc.in_features, out_features=61),
)
nir_net.to(device)

lr, epochs, batch_size = 0.001, 40, 25
optimizer = torch.optim.Adam(nir_net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(nir_train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(nir_val_dataset, batch_size=batch_size, shuffle=False)

nir_net = fit(nir_net, epochs, optimizer, criterion, train_loader, val_loader)

In [None]:
# ResNet-50 classifier for the incandescent images. Unlike the NIR, CycleGAN and CUT
# classifiers in this notebook, only conv1 and layer1 are frozen and the learning rate
# is 10x lower.
# NOTE(review): `models.resnet50()` is built without pretrained weights, so the frozen
# layers keep their random initialisation. Confirm whether pretrained weights were intended.
inc_net = models.resnet50()
inc_net.requires_grad_(True)
for frozen in (inc_net.conv1, inc_net.layer1):
    frozen.requires_grad_(False)
inc_net.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features=inc_net.fc.in_features, out_features=61),
)
inc_net.to(device)

lr, epochs, batch_size = 0.0001, 40, 25
optimizer = torch.optim.Adam(inc_net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(inc_train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(inc_val_dataset, batch_size=batch_size, shuffle=False)

inc_net = fit(inc_net, epochs, optimizer, criterion, train_loader, val_loader)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-50 classifier trained on the CycleGAN-translated images: conv1, layer1 and
# layer2 are frozen and the head becomes dropout + a 61-way linear layer.
# NOTE(review): `models.resnet50()` is built without pretrained weights, so the frozen
# layers keep their random initialisation. Confirm whether pretrained weights were intended.
cycle_gan_net = models.resnet50()
cycle_gan_net.requires_grad_(True)
for frozen in (cycle_gan_net.conv1, cycle_gan_net.layer1, cycle_gan_net.layer2):
    frozen.requires_grad_(False)
cycle_gan_net.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features=cycle_gan_net.fc.in_features, out_features=61),
)
cycle_gan_net.to(device)

lr, epochs, batch_size = 0.001, 40, 25
optimizer = torch.optim.Adam(cycle_gan_net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(cycle_gan_train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(cycle_gan_val_dataset, batch_size=batch_size, shuffle=False)

cycle_gan_net = fit(cycle_gan_net, epochs, optimizer, criterion, train_loader, val_loader)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-50 classifier trained on the CUT-translated images: conv1, layer1 and layer2
# are frozen and the head becomes dropout + a 61-way linear layer.
# NOTE(review): `models.resnet50()` is built without pretrained weights, so the frozen
# layers keep their random initialisation. Confirm whether pretrained weights were intended.
cut_net = models.resnet50()
cut_net.requires_grad_(True)
for frozen in (cut_net.conv1, cut_net.layer1, cut_net.layer2):
    frozen.requires_grad_(False)
cut_net.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features=cut_net.fc.in_features, out_features=61),
)
cut_net.to(device)

lr, epochs, batch_size = 0.001, 40, 25
optimizer = torch.optim.Adam(cut_net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(cut_train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(cut_val_dataset, batch_size=batch_size, shuffle=False)

cut_net = fit(cut_net, epochs, optimizer, criterion, train_loader, val_loader)

In [None]:
import numpy as np


def test(model, test_loader, show_images=20):
    model.eval()

    running_acc = 0
    total = 0
    with torch.no_grad():
        for counter, data in enumerate(test_loader):
            image, target = data['image'].to(device), data['labels']
            outputs = model(image).detach().cpu()
            acc = accuracy(outputs, target)

            prob = torch.softmax(outputs, dim=1)[0]
            pred = torch.argmax(prob)

            running_acc += acc

            if total < show_images:
                image = image.squeeze(0)
                image = image.detach().cpu().numpy()
                image = np.transpose(image, (1, 2, 0))
                plt.imshow(image)
                plt.axis('off')
                plt.title(f"PREDICTED: {pred}\nACTUAL: {target.item()} ({prob[pred]})")
                plt.show()

            total += 1

    return running_acc / total

In [None]:
# NIR classifier on the NIR test split, one image per batch.
test_loader = DataLoader(nir_test_dataset, shuffle=False, batch_size=1)
test(nir_net, test_loader, show_images=3)

In [None]:
# Incandescent classifier on the incandescent test split, one image per batch.
test_loader = DataLoader(inc_test_dataset, shuffle=False, batch_size=1)
test(inc_net, test_loader, show_images=3)

In [None]:
# CycleGAN-trained classifier on the CycleGAN-translated test split.
cycle_gan_test_loader = DataLoader(cycle_gan_test_dataset, shuffle=False, batch_size=1)
test(cycle_gan_net, cycle_gan_test_loader, show_images=3)

In [None]:
# CUT-trained classifier on the CUT-translated test split.
cut_test_loader = DataLoader(cut_test_dataset, shuffle=False, batch_size=1)
test(cut_net, cut_test_loader, show_images=3)

In [None]:
def test_assemble(model1, model2, test_loader1, test_loader_2, show_images=20):
    model1.eval()
    model2.eval()

    running_acc = 0
    total = 0
    with torch.no_grad():
        for _, (data1, data2) in enumerate(zip(test_loader1, test_loader_2)):
            image1, target1 = data1['image'].to(device), data1['labels']
            outputs1 = model1(image1).detach().cpu()
            prob1 = torch.softmax(outputs1, dim=1)[0]
            pred1 = torch.argmax(prob1)

            image2, target2 = data2['image'].to(device), data2['labels']
            outputs2 = model2(image2).detach().cpu()
            prob2 = torch.softmax(outputs2, dim=1)[0]
            pred2 = torch.argmax(prob2)

            if prob2[pred2] > prob1[pred1]:
                pred = pred2
            else:
                pred = pred1

            if pred == target1:
                running_acc += 1

            if total < show_images:
                image2 = image2.squeeze(0)
                image2 = image2.detach().cpu().numpy()
                image2 = np.transpose(image2, (1, 2, 0))
                plt.imshow(image2)
                plt.axis('off')
                plt.title(f"PREDICTED: {pred}\nACTUAL: {target1.item()} ({prob1[pred1]}, {prob2[pred2]})")
                plt.show()

            total += 1

    return running_acc / total

In [None]:
# Ensemble of the NIR classifier and the CycleGAN-trained classifier. test_assemble
# pairs the two loaders batch by batch, so both must iterate the test images in the
# same order.
nir_test_loader = DataLoader(nir_test_dataset, shuffle=False, batch_size=1)
cycle_gan_test_loader = DataLoader(cycle_gan_test_dataset, shuffle=False, batch_size=1)

test_assemble(nir_net, cycle_gan_net, nir_test_loader, cycle_gan_test_loader, show_images=3)