In [1]:
import os
import time

from torchvision import transforms
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

from loader import load_png_images
from implementation import CNNClassifier, train_model

In [2]:
def prettify_duration(duration_in_seconds):
    s = duration_in_seconds % 60
    duration_in_seconds -= s
    duration_in_seconds //= 60

    m = duration_in_seconds % 60
    duration_in_seconds -= m
    duration_in_seconds //= 60

    return f"{duration_in_seconds}h {m}m {s}s"

In [3]:
PREFIX = "sample"

In [4]:
train_path = os.path.join(os.getcwd(), "data", PREFIX, "train")
valid_path = os.path.join(os.getcwd(), "data", PREFIX, "valid")
test_path = os.path.join(os.getcwd(), "data", PREFIX, "test")

In [None]:
transformations = {
    
}

In [6]:
for transformation_name, transformation in transformations.items():
    print("==============================================================================")
    print(f"transformation: {transformation_name}")
    print("------------------------------------------------------------------------------")
    
    train, t_n = load_png_images(train_path, transform=transformation, batch_size=64)
    valid, v_n = load_png_images(valid_path, transform=transformation, batch_size=2048)

    m_models = {
        "custom-cnn": CNNClassifier(t_n),
        "densnet": models.densenet121(weights="DenseNet121_Weights.IMAGENET1K_V1"),
        "wide-resnet": models.wide_resnet50_2(weights="Wide_ResNet50_2_Weights.DEFAULT")
    }
    m_models["densnet"].classifier = nn.Linear(m_models["densnet"].classifier.in_features, t_n)
    m_models["wide-resnet"].fc = nn.Linear(m_models["wide-resnet"].fc.in_features, t_n)
    
    for model_name, model in m_models.items():
        print(f"model: {model_name}")
        print()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        
        start = time.time()
        train_model(model, train, valid, criterion, optimizer, epochs=10)
        end = time.time()

        print(f"training took {prettify_duration(end - start)}")
        print()

transformation: no-transform
------------------------------------------------------------------------------
model: custom-cnn

Epoch 1/10, Loss: 0.6950, Validation Accuracy: 57.60%
Epoch 2/10, Loss: 0.6724, Validation Accuracy: 63.00%
Epoch 3/10, Loss: 0.6240, Validation Accuracy: 66.40%
Epoch 4/10, Loss: 0.5965, Validation Accuracy: 67.20%
Epoch 5/10, Loss: 0.5664, Validation Accuracy: 68.00%
Epoch 6/10, Loss: 0.5296, Validation Accuracy: 69.80%
Epoch 7/10, Loss: 0.5240, Validation Accuracy: 67.60%
Epoch 8/10, Loss: 0.5225, Validation Accuracy: 73.00%
Epoch 9/10, Loss: 0.4393, Validation Accuracy: 73.40%
Epoch 10/10, Loss: 0.3915, Validation Accuracy: 74.20%
training took 0.0h 1.0m 10.306885242462158s

model: densnet

Epoch 1/10, Loss: 0.6440, Validation Accuracy: 70.80%
Epoch 2/10, Loss: 0.3695, Validation Accuracy: 73.20%
Epoch 3/10, Loss: 0.2403, Validation Accuracy: 78.80%
Epoch 4/10, Loss: 0.0825, Validation Accuracy: 75.40%
Epoch 5/10, Loss: 0.0715, Validation Accuracy: 77.20%
E