In [None]:
import torch
from torch import nn
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.optim import Adam


In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [14]:
train_data = ImageFolder(root="/home/rynutty/Documents/ProgrammingProjects/CustomModels/datasets/brisc2025/classification_task/train", transform=transform)
train_dataloader = DataLoader(train_data, 24, shuffle=True)

images, labels = next(iter(train_dataloader))

print(images.shape)
print(labels.shape)

torch.Size([24, 3, 224, 224])
torch.Size([24])


In [15]:
test_data = ImageFolder("/home/rynutty/Documents/ProgrammingProjects/CustomModels/datasets/brisc2025/classification_task/test", transform=transform)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

images, labels = next(iter(test_dataloader))

print(images.shape)
print(labels.shape)

torch.Size([64, 3, 224, 224])
torch.Size([64])


In [7]:
class BriscCNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=32),
            nn.MaxPool2d(kernel_size=5, stride=2),
            nn.Dropout2d(p=0.25)
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=64),
            nn.MaxPool2d(kernel_size=5, stride=2),
            nn.Dropout2d(p=0.25)
        )

        dummy_ex = torch.randn((1, 3, 224, 224))
        dummy_ex = self.conv_block1(dummy_ex)
        dummy_ex = self.conv_block2(dummy_ex)
        in_features = dummy_ex.view(-1)
        
        self.dense = nn.Linear(in_features=in_features.shape[0], out_features=4)

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = x.view(x.size(0), -1)
        return self.dense(x)


In [8]:
model = BriscCNN()
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.01)
loss_fn = nn.CrossEntropyLoss()

In [74]:
epochs = 5

for epoch in range(1, epochs+1):
    print(f"Training Epoch {epoch}...")

    train_working_loss = 0
    train_total = 0
    train_correct = 0

    model.train()
    for image_batch, label_batch in train_dataloader:
        optimizer.zero_grad()
        logits = model(image_batch)
        loss = loss_fn(logits, label_batch)
        loss.backward()
        optimizer.step()

        train_working_loss += loss.item()
        train_total += len(image_batch)
        predicted = logits.max(1)
        train_correct += (predicted.indices == label_batch).sum().item()

    print(f"Avg Epoch Loss: {train_working_loss / train_total}, Accuracy: {train_correct / train_total}")
    
    test_working_loss = 0
    test_correct = 0
    test_total = 0

    print(f"Evaluating Epoch {epoch}...")

    model.eval()
    with torch.no_grad():
        for image_batch, label_batch in test_dataloader:
            logits = model(image_batch)
            loss = loss_fn(logits, label_batch)

            test_working_loss += loss.item()
            test_total += len(image_batch)
            predicted = logits.max(1)
            test_correct += (predicted.indices == label_batch).sum().item()

        print(f"Avg Epoch Loss: {test_working_loss / test_total}, Accuracy: {test_correct / test_total}")


Training Epoch 1...
Avg Epoch Loss: 0.26096274664357766, Accuracy: 0.742
Evaluating Epoch 1...
Avg Epoch Loss: 0.04607568550109863, Accuracy: 0.807
Training Epoch 2...
Avg Epoch Loss: 0.09800977948989013, Accuracy: 0.8782
Evaluating Epoch 2...
Avg Epoch Loss: 0.030572465658187865, Accuracy: 0.863
Training Epoch 3...
Avg Epoch Loss: 0.04117216487390573, Accuracy: 0.9372
Evaluating Epoch 3...
Avg Epoch Loss: 0.029527729652822018, Accuracy: 0.881
Training Epoch 4...
Avg Epoch Loss: 0.030269136799403896, Accuracy: 0.951
Evaluating Epoch 4...
Avg Epoch Loss: 0.023087321817874908, Accuracy: 0.905
Training Epoch 5...
Avg Epoch Loss: 0.01917389002093844, Accuracy: 0.9632
Evaluating Epoch 5...
Avg Epoch Loss: 0.02970124664902687, Accuracy: 0.887


In [9]:
epochs = 5

for epoch in range(1, epochs+1):
    print(f"Training Epoch {epoch}...")

    train_working_loss = 0
    train_total = 0
    train_correct = 0

    model.train()
    for image_batch, label_batch in train_dataloader:
        optimizer.zero_grad()
        logits = model(image_batch)
        loss = loss_fn(logits, label_batch)
        loss.backward()
        optimizer.step()

        train_working_loss += loss.item()
        train_total += len(image_batch)
        predicted = logits.max(1)
        train_correct += (predicted.indices == label_batch).sum().item()

    print(f"Avg Epoch Loss: {train_working_loss / train_total}, Accuracy: {train_correct / train_total}")
    
    test_working_loss = 0
    test_correct = 0
    test_total = 0

    print(f"Evaluating Epoch {epoch}...")

    model.eval()
    with torch.no_grad():
        for image_batch, label_batch in test_dataloader:
            logits = model(image_batch)
            loss = loss_fn(logits, label_batch)

            test_working_loss += loss.item()
            test_total += len(image_batch)
            predicted = logits.max(1)
            test_correct += (predicted.indices == label_batch).sum().item()

        print(f"Avg Epoch Loss: {test_working_loss / test_total}, Accuracy: {test_correct / test_total}")

Training Epoch 1...
Avg Epoch Loss: 1.2013620707914845, Accuracy: 0.7252
Evaluating Epoch 1...
Avg Epoch Loss: 0.33885960674285887, Accuracy: 0.749
Training Epoch 2...
Avg Epoch Loss: 0.7967201314616504, Accuracy: 0.8138
Evaluating Epoch 2...
Avg Epoch Loss: 0.271345986366272, Accuracy: 0.809
Training Epoch 3...
Avg Epoch Loss: 0.6779656310094432, Accuracy: 0.8434
Evaluating Epoch 3...
Avg Epoch Loss: 0.3037311329841614, Accuracy: 0.798
Training Epoch 4...
Avg Epoch Loss: 0.6330756516188865, Accuracy: 0.851
Evaluating Epoch 4...
Avg Epoch Loss: 0.20187254667282103, Accuracy: 0.803
Training Epoch 5...
Avg Epoch Loss: 0.5169731878939379, Accuracy: 0.8714
Evaluating Epoch 5...
Avg Epoch Loss: 0.22706766176223755, Accuracy: 0.852


In [11]:
class BriscClassification(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=32),
            nn.MaxPool2d(kernel_size=4, stride=2),
            nn.Dropout2d(p=0.2)
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=64),
            nn.MaxPool2d(kernel_size=4, stride=2),
            nn.Dropout(p=0.2)
        )

        self.conv_block3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=128),
            nn.MaxPool2d(kernel_size=4, stride=2),
            nn.Dropout(p=0.2)
        )

        dummy_ex = torch.randn((1, 3, 224, 224))
        dummy_ex = self.conv_block1(dummy_ex)
        dummy_ex = self.conv_block2(dummy_ex)
        dummy_ex = self.conv_block3(dummy_ex)
        dummy_ex = torch.flatten(dummy_ex, start_dim=1, end_dim=-1)

        self.fc = nn.Linear(in_features=dummy_ex.size(1), out_features=4)

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = torch.flatten(x, start_dim=1, end_dim=-1)
        return self.fc(x)

In [16]:
model = BriscClassification()
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.01)
loss_fn = nn.CrossEntropyLoss()

In [17]:
epochs = 7

for epoch in range(1, epochs+1):
    print(f"Training Epoch {epoch}...")

    train_working_loss = 0
    train_total = 0
    train_correct = 0

    model.train()
    for image_batch, label_batch in train_dataloader:
        optimizer.zero_grad()
        logits = model(image_batch)
        loss = loss_fn(logits, label_batch)
        loss.backward()
        optimizer.step()

        train_working_loss += loss.item()
        train_total += len(image_batch)
        predicted = logits.max(1)
        train_correct += (predicted.indices == label_batch).sum().item()

    print(f"Avg Epoch Loss: {train_working_loss / train_total}, Accuracy: {train_correct / train_total}")
    
    test_working_loss = 0
    test_correct = 0
    test_total = 0

    print(f"Evaluating Epoch {epoch}...")

    model.eval()
    with torch.no_grad():
        for image_batch, label_batch in test_dataloader:
            logits = model(image_batch)
            loss = loss_fn(logits, label_batch)

            test_working_loss += loss.item()
            test_total += len(image_batch)
            predicted = logits.max(1)
            test_correct += (predicted.indices == label_batch).sum().item()

        print(f"Avg Epoch Loss: {test_working_loss / test_total}, Accuracy: {test_correct / test_total}")

Training Epoch 1...
Avg Epoch Loss: 0.35713809664044527, Accuracy: 0.7114
Evaluating Epoch 1...
Avg Epoch Loss: 0.10761877560615539, Accuracy: 0.8
Training Epoch 2...
Avg Epoch Loss: 0.19071498335464857, Accuracy: 0.8174
Evaluating Epoch 2...
Avg Epoch Loss: 0.06916032886505127, Accuracy: 0.816
Training Epoch 3...
Avg Epoch Loss: 0.1213436385545084, Accuracy: 0.8562
Evaluating Epoch 3...
Avg Epoch Loss: 0.08258754253387451, Accuracy: 0.8
Training Epoch 4...
Avg Epoch Loss: 0.08168597337187371, Accuracy: 0.8842
Evaluating Epoch 4...
Avg Epoch Loss: 0.0687889153957367, Accuracy: 0.791
Training Epoch 5...
Avg Epoch Loss: 0.09230625452424245, Accuracy: 0.8776
Evaluating Epoch 5...
Avg Epoch Loss: 0.04644046488404274, Accuracy: 0.861
Training Epoch 6...
Avg Epoch Loss: 0.05958367577162935, Accuracy: 0.903
Evaluating Epoch 6...
Avg Epoch Loss: 0.0307027205824852, Accuracy: 0.876
Training Epoch 7...
Avg Epoch Loss: 0.05116089496614854, Accuracy: 0.9074
Evaluating Epoch 7...
Avg Epoch Loss: 0.