# Black magic (hehehe): merging two stacked convolutions into a single equivalent convolution

In [4]:
import torch
import numpy as np

def merge_two_conv(conv1, conv2):
    """Fuse two consecutive ``Conv2d`` layers into a single equivalent layer.

    The returned layer satisfies ``new_conv(x) == conv2(conv1(x))`` up to
    floating-point rounding. Two stacked (linear) convolutions with kernels
    ``k1`` and ``k2`` collapse into one convolution whose kernel has spatial
    size ``k1 + k2 - 1``.

    Args:
        conv1: The convolution applied first. Must use stride 1, no padding,
            dilation 1 and ``groups == 1``. Its bias may be ``None``.
        conv2: The convolution applied to ``conv1``'s output, with the same
            restrictions. Its bias may be ``None``.

    Returns:
        A new ``torch.nn.Conv2d`` on the same device and with the same dtype
        as ``conv1.weight``. It has a bias iff at least one input layer has one.

    Raises:
        ValueError: If either layer uses a stride, padding, dilation or group
            count for which this fusion is not valid.
    """
    for name, conv in (("conv1", conv1), ("conv2", conv2)):
        unsupported = (
            tuple(conv.stride) != (1, 1)
            or tuple(conv.dilation) != (1, 1)
            or conv.groups != 1
            or conv.padding not in ((0, 0), "valid")
        )
        if unsupported:
            raise ValueError(
                f"{name} must use stride=1, padding=0, dilation=1 and groups=1 "
                "to be merged"
            )

    w1 = conv1.weight.data  # (inner, in, kh1, kw1)
    w2 = conv2.weight.data  # (out, inner, kh2, kw2)
    _, in_channels, kh1, kw1 = w1.shape
    out_channels, _, kh2, kw2 = w2.shape
    # Plain Python ints, so the new layer's kernel_size is an ordinary tuple.
    merged_kernel = (kh1 + kh2 - 1, kw1 + kw2 - 1)

    # torch.conv2d is a cross-correlation. Treating conv1's kernels as a batch
    # of "images" (one per input channel) and sliding the spatially flipped
    # conv2 kernels over them with full padding yields the composed kernel.
    merged_weight = torch.conv2d(
        w1.permute(1, 0, 2, 3),
        w2.flip(-1, -2),
        padding=(kh2 - 1, kw2 - 1),
    ).permute(1, 0, 2, 3).contiguous()

    # conv1's bias is a constant map fed into conv2, so each output channel
    # receives sum over (inner, kh2, kw2) of w2 * bias1[inner].
    bias = None
    if conv1.bias is not None:
        bias = w2.sum(dim=(2, 3)) @ conv1.bias.data
    if conv2.bias is not None:
        if bias is None:
            bias = conv2.bias.data.clone()
        else:
            bias = bias + conv2.bias.data

    new_conv = torch.nn.Conv2d(
        in_channels, out_channels, merged_kernel, bias=bias is not None
    )
    # Follow the inputs' device/dtype instead of assuming CPU float32.
    new_conv = new_conv.to(device=w1.device, dtype=w1.dtype)
    new_conv.weight.data = merged_weight
    if bias is not None:
        new_conv.bias.data = bias
    return new_conv

In [7]:
# Sanity check: the fused layer must reproduce conv2(conv1(x)) on random input.
conv1 = torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3)
conv2 = torch.nn.Conv2d(in_channels=3, out_channels=5, kernel_size=(4, 5))
new_conv = merge_two_conv(conv1, conv2)

x = torch.randn(1, 2, 9, 9)

# Every element of the two outputs has to agree to within 1e-6.
assert ((conv2(conv1(x)) - new_conv(x)).abs() < 1e-6).all()
# Cell display value: conv1's bias reshaped for broadcasting over (N, C, H, W).
conv1.bias.data[None, :, None, None]

torch.Size([5, 3, 4, 5])

In [None]:
from tqdm import tqdm

import torch
from torch.utils.data import Dataset
import torchvision.datasets
from torchvision.transforms import transforms



# Estimate per-channel statistics of the raw images in ./characters.
# Each image contributes its own channel-wise mean and standard deviation;
# the reported values are the averages of those per-image numbers.
transformations = transforms.Compose([transforms.ToTensor()])
dataset = torchvision.datasets.ImageFolder(root="./characters", transform=transformations)

# One row per image, one column per colour channel.
images_means = torch.empty(len(dataset), 3)
images_stds = torch.empty(len(dataset), 3)
# Allocated alongside the image buffers but never filled in this cell.
labels_means = torch.empty(len(dataset), 3)
labels_stds = torch.empty(len(dataset), 3)

for ind, (image, label) in enumerate(tqdm(dataset)):
    # Reduce over the two spatial axes, keeping the channel axis.
    images_means[ind] = image.mean(dim=(1, 2))
    images_stds[ind] = image.std(dim=(1, 2))

# Average the per-image statistics over the whole dataset.
images_mean = images_means.mean(dim=0)
images_std = images_stds.mean(dim=0)

print(f"Images means: {images_mean}")
print(f"Images std: {images_std}")


In [None]:
import os

from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
import torchvision.datasets
import torch
import PIL
# Run configuration shared by the cells below.
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
batch_size = 128
number_of_labels = 42
learning_rate = 1e-3
num_epochs = 150
# Character (class) names; a name's position in this tuple is its integer id.
classes = (
    'abraham_grampa_simpson',
    'agnes_skinner',
    'apu_nahasapeemapetilon',
    'barney_gumble',
    'bart_simpson',
    'carl_carlson',
    'charles_montgomery_burns',
    'chief_wiggum',
    'cletus_spuckler',
    'comic_book_guy',
    'disco_stu',
    'edna_krabappel',
    'fat_tony',
    'gil',
    'groundskeeper_willie',
    'homer_simpson',
    'kent_brockman',
    'krusty_the_clown',
    'lenny_leonard',
    'lionel_hutz',
    'lisa_simpson',
    'maggie_simpson',
    'marge_simpson',
    'martin_prince',
    'mayor_quimby',
    'milhouse_van_houten',
    'miss_hoover',
    'moe_szyslak',
    'ned_flanders',
    'nelson_muntz',
    'otto_mann',
    'patty_bouvier',
    'principal_skinner',
    'professor_john_frink',
    'rainier_wolfcastle',
    'ralph_wiggum',
    'selma_bouvier',
    'sideshow_bob',
    'sideshow_mel',
    'snake_jailbird',
    'troy_mcclure',
    'waylon_smithers',
)
# Reverse lookup: class name -> integer id (its index in `classes`).
class_encoder = {name: index for index, name in enumerate(classes)}
class CustomImageDataset(Dataset):
    """Dataset over a flat directory of image files named ``<class>_<suffix>``.

    The class name is the part of the file name before its last underscore;
    it is translated to an integer id through the module-level
    ``class_encoder`` mapping.
    """

    def __init__(self, img_dir, transform=None, target_transform=None):
        """Index the files found in ``img_dir``.

        Args:
            img_dir: Directory whose entries are the image files.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each integer label.
        """
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible between runs and machines.
        self.img_labels = sorted(os.listdir(img_dir))
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of files indexed in ``img_dir``."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the ``idx``-th file and return ``(image, label)``.

        Raises:
            KeyError: If the file-name prefix is not a key of ``class_encoder``.
        """
        # A bare ``import PIL`` does not load the ``Image`` submodule, so
        # import it explicitly rather than relying on another package having
        # imported it first.
        from PIL import Image

        file_name = self.img_labels[idx]
        image = Image.open(os.path.join(self.img_dir, file_name))
        # Everything before the last underscore is the class name.
        class_str = file_name[:file_name.rfind('_')]
        label = class_encoder[class_str]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
# Preprocessing applied to every image: convert to a tensor, normalise each
# channel with fixed mean/std values, then resize to 32x32.
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4622, 0.4075, 0.3524], std=[0.2167, 0.1961, 0.2246]),
    transforms.Resize((32, 32), antialias=True),
])


full_dataset = torchvision.datasets.ImageFolder("./characters", transformations)

# Split so the three subsets are disjoint: hold out 20% for testing first,
# then divide the remaining 80% into 70% train / 30% validation. Splitting
# `full_dataset` twice would let test/validation images leak into training.
train_valid_dataset, test_set = torch.utils.data.random_split(full_dataset, [0.8, 0.2])
train_dataset, valid_dataset = torch.utils.data.random_split(train_valid_dataset, [0.7, 0.3])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16)
# Report the real sample counts; len(loader) * batch_size rounds up to a
# whole number of batches and therefore over-counts.
print("The number of images in a training set is: ", len(train_dataset))

test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=16)
print("The number of images in a test set is: ", len(test_set))
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=16)
print("The number of images in validation set is: ", len(valid_dataset))
print("The number of batches per epoch is: ", len(train_loader))


In [None]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

class Network(nn.Module):
    """VGG-style convolutional classifier.

    Ten double-convolution blocks (3x3 conv -> batch norm -> ReLU, twice)
    interleaved with four 2x2 max-pools, followed by global average pooling
    and one linear layer.

    Args:
        num_classes: Number of output logits. Defaults to 42.

    NOTE(review): the previous ``forward`` called a non-existent ``self.conv``
    and the first max-pool was overwritten by the final average pool (both
    were stored as ``self.pool``). The layers are now applied in the order
    they are defined; confirm this matches the architecture the saved
    checkpoint was trained with. Parameter names (``conv1`` ... ``conv10``,
    ``fc1``) are unchanged and pooling layers hold no parameters, so
    ``state_dict`` keys are the same as before.
    """

    def __init__(self, num_classes=42):
        super().__init__()

        self.conv1 = self._double_conv(3, 64, mid_channels=32)
        self.conv2 = self._double_conv(64, 64)
        # Renamed from ``pool``: that attribute is reused below for the
        # global average pool, which silently replaced this layer.
        self.pool0 = nn.MaxPool2d(2, 2)
        self.conv4 = self._double_conv(64, 128)
        self.conv5 = self._double_conv(128, 128)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv6 = self._double_conv(128, 256)
        self.conv7 = self._double_conv(256, 256)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv8 = self._double_conv(256, 512)
        self.conv9 = self._double_conv(512, 512)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.conv10 = self._double_conv(512, 512)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512, num_classes)

    @staticmethod
    def _double_conv(in_channels, out_channels, mid_channels=None):
        """Build conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> ReLU."""
        if mid_channels is None:
            mid_channels = out_channels
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, input):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            input: Image batch with 3 channels; height and width must be at
                least 16 so that the four 2x2 max-pools leave a non-empty map.
        """
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.pool0(output)
        output = self.conv4(output)
        output = self.conv5(output)
        output = self.pool1(output)
        output = self.conv6(output)
        output = self.conv7(output)
        output = self.pool2(output)
        output = self.conv8(output)
        output = self.conv9(output)
        output = self.pool3(output)
        output = self.conv10(output)
        # Global average pool to (batch, 512, 1, 1), then drop spatial dims.
        output = self.pool(output)
        output = torch.flatten(output, 1)
        output = self.fc1(output)
        return output

# Build the network and move its parameters to the selected device.
model = Network().to(device)

In [None]:
from torch.optim import SGD
# Training objective, optimiser and learning-rate schedule for `model`.
loss_fn = nn.CrossEntropyLoss()
optimizer = SGD(params=model.parameters(), momentum=0.9, lr=learning_rate)
# Multiply the learning rate by `factor` once the value passed to
# scheduler.step() has stopped improving for more than `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    threshold=1e-2,
    patience=3,
    factor=0.1,
    verbose=True,
)

In [None]:
def get_num_correct(preds, labels):
    """Count the rows of ``preds`` whose arg-max along dim 1 equals the label.

    Args:
        preds: Tensor of per-class scores, one row per sample.
        labels: Tensor of integer class ids, one per sample.

    Returns:
        The number of matching rows as a Python number.
    """
    predicted_classes = preds.argmax(dim=1)
    hits = predicted_classes == labels
    return hits.sum().item()

In [None]:
from torch.autograd import Variable
import tqdm

def saveModel():
    """Write the global ``model``'s ``state_dict`` to ``./simpsons.pth``."""
    torch.save(model.state_dict(), "./simpsons.pth")

def testAccuracy():
    """Evaluate the global ``model`` on ``test_loader`` and return its F1 score.

    The model is switched to evaluation mode for the duration of the call and
    is put back into whatever mode it was in afterwards. Without that, a
    training loop that calls this function every epoch would keep training
    in eval mode from the second epoch onwards.

    Returns:
        The value of ``torchmetrics.F1Score.compute()`` accumulated over all
        test batches.
    """
    was_training = model.training
    model.eval()
    # Use the shared class count rather than a second hard-coded 42.
    metric = torchmetrics.F1Score(task="multiclass", num_classes=number_of_labels).to(device)
    try:
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)
                # Only accumulate state; the per-batch value is not needed.
                metric.update(predicted, labels)
    finally:
        # Restore the caller's train/eval mode even if evaluation failed.
        model.train(was_training)
    f1 = metric.compute()
    return f1
# Module-level history buffers, all starting empty; each name is bound to its
# own list object.
loss_metric, recall_metric, accuracy_metric, lr_metric = [], [], [], []

def train():
    """Train the global ``model`` for ``num_epochs`` epochs.

    Each epoch: run one pass over ``train_loader``, step the LR scheduler on
    the mean training loss, evaluate with ``testAccuracy()``, log loss /
    correct-count / F1 to TensorBoard, and save the weights via
    ``saveModel()`` whenever the F1 score improves. After the last epoch the
    final training batch is logged as an image grid together with the model
    graph.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no batches.
    """
    # `tqdm.tnrange` / `tqdm.tqdm_notebook` are deprecated aliases of these;
    # a bare `import tqdm` does not load the `notebook` submodule, so import
    # the names explicitly.
    from tqdm.notebook import tqdm as notebook_tqdm, trange

    best_f1 = 0.0
    print("The model will be running on", device, "device")
    comment = f' batch_size = {batch_size} lr = {learning_rate}'
    tb = SummaryWriter(comment=comment)
    last_images = None
    try:
        for epoch in trange(num_epochs, position=0, desc="Epochs"):
            # Evaluation switches the model to eval mode; make sure BatchNorm
            # is back in training mode at the start of every epoch.
            model.train()
            losses = []
            total_correct = 0
            for images, labels in notebook_tqdm(train_loader, position=1, desc="Batch iter", leave=True):
                images = images.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                total_correct += get_num_correct(outputs, labels)
                loss = loss_fn(outputs, labels)
                losses.append(loss.item())
                loss.backward()
                optimizer.step()
                last_images = images
            mean_loss = sum(losses) / len(losses)
            scheduler.step(mean_loss)
            print(f"Loss at epoch {epoch} = {mean_loss}")
            f1 = testAccuracy()
            print(f"For epoch {epoch} F1: {f1}")
            tb.add_scalar("Loss", mean_loss, epoch)
            tb.add_scalar("Correct", total_correct, epoch)
            tb.add_scalar("F1", f1, epoch)

            # Keep only the best-scoring weights on disk.
            if f1 > best_f1:
                saveModel()
                best_f1 = f1

        # Nothing to visualise if no batch was ever processed.
        if last_images is not None:
            grid = torchvision.utils.make_grid(last_images)
            tb.add_image("images", grid)
            tb.add_graph(model, last_images)
    finally:
        # Flush and close the event file even when training is interrupted.
        tb.close()

In [None]:
import torchmetrics

def testClassess():
    """Print the per-class F1 score of the global ``model`` on ``valid_loader``.

    The model is put into evaluation mode while the scores are computed and
    restored to its previous mode afterwards. A freshly constructed model is
    in training mode, where BatchNorm would normalise with batch statistics
    and overwrite its running averages.
    """
    metric = torchmetrics.F1Score(
        task="multiclass", num_classes=number_of_labels, average=None
    ).to(device)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)
                metric.update(predicted, labels)
    finally:
        model.train(was_training)
    # With average=None the metric yields one score per class.
    per_class_f1 = metric.compute()
    for i in range(number_of_labels):
        print(f'F1 of {classes[i]} : {per_class_f1[i]}')

In [None]:
import torchinfo


# Layer-by-layer summary of `model` for a batch of 128 RGB 32x32 inputs.
torchinfo.summary(
    model,
    input_size=(128, 3, 32, 32),
    depth=2,
    col_names=[
        "input_size",
        "output_size",
        "num_params",
        "params_percent",
        "kernel_size",
        "mult_adds",
        "trainable",
    ],
    row_settings=["var_names"],
    verbose=0,
)

In [None]:
if __name__ == "__main__":
    # Training is currently disabled; uncomment the call below to retrain
    # before evaluating the saved checkpoint.
    #train()
    print('Finished Training')
    model = Network().to(device)
    path = "simpsons.pth"
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved on a GPU still loads on a CPU-only machine.
    model.load_state_dict(torch.load(path, map_location=device))
    testClassess()
    print(testAccuracy())