# SingleCNN Training Script

## Import libraries

In [1]:
import sys
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from utils.utils import seed_worker
from utils.constants import DATA_PATH
from utils.enums import RetrieveDataType
from data_retriever.data_retriever import DataRetriever
from models.single_cnn import SingleCNN

In [2]:
# Display the installed PyTorch version (recorded alongside results for reproducibility).
torch.__version__

'2.7.0+cu118'

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [4]:
# CPU random generator seeded with 0; it is passed to the DataLoaders as
# `generator` so their shuffling is reproducible across runs.
g = torch.Generator(device="cpu")
g.manual_seed(0)

<torch._C.Generator at 0x1cf17351030>

## Retrieve training data

In [5]:
# Project data-access helper constructed with DATA_PATH; used below to fetch the train/test splits.
data_retriever = DataRetriever(DATA_PATH)

In [6]:
train_data = data_retriever.get_data(retrieve_data_type=RetrieveDataType.TRAIN_NOISE)

# Convert the labels to a tensor and keep them on `cluster`, then replace
# `targets` with a fresh all-zero tensor (`cluster` still holds the originals
# because `zeros_like` allocates a new tensor rather than mutating in place).
# NOTE(review): after this every training target is 0, while the model is built
# with num_classes=10 and the training loop consumes only (inputs, targets) —
# confirm that collapsing all labels to class 0 is intended.
train_data.targets = torch.tensor(train_data.targets)
train_data.cluster = train_data.targets
train_data.targets = torch.zeros_like(train_data.targets)

In [7]:
# Shuffled training batches of 128 samples loaded by 2 worker processes.
# The seeded generator `g` fixes the shuffle order; `seed_worker` is the
# worker_init_fn (presumably seeding each worker's RNGs — see utils.utils).
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=128,
    num_workers=2,
    shuffle=True,
    generator=g,
    worker_init_fn=seed_worker,
)

## Retrieve test data

In [None]:
test_data = data_retriever.get_data(retrieve_data_type=RetrieveDataType.TEST_NOISE)

# Same relabelling as the training split: original labels move to `cluster`
# and `targets` becomes a new all-zero tensor.
# NOTE(review): every test target is therefore 0, so reported test accuracy
# only measures how often the model predicts class 0 — confirm this is intended.
test_data.targets = torch.tensor(test_data.targets)
test_data.cluster = test_data.targets
test_data.targets = torch.zeros_like(test_data.targets)

In [9]:
# Evaluation batches of 100 samples. The summed loss and accuracy do not
# depend on sample order, so the test set is iterated in a fixed order
# (no shuffling); this also keeps per-batch logs comparable between runs.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=100,
    shuffle=False,
    num_workers=2,
    worker_init_fn=seed_worker,
    generator=g,
)

## Create the model, optimizer, and learning-rate scheduler

In [10]:
# SingleCNN for 3-channel inputs with 10 output classes, moved to `device`.
model = SingleCNN(input_channels=3, num_classes=10)
model = model.to(device)

# SGD with momentum and L2 weight decay.
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-2,
    momentum=0.9,
    weight_decay=5e-4,
)

# Cosine learning-rate decay over T_max=200 scheduler steps.
# NOTE(review): the visible training loop runs a single epoch, so the schedule
# barely moves from its initial LR — confirm T_max matches the intended epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

## Loss criterion

In [11]:
# Cross-entropy over the model's logits, averaged over the batch.
criterion = nn.CrossEntropyLoss(reduction="mean")

## Train the model

In [12]:
# Run bookkeeping: `best_acc` is the best test accuracy of the current run
# (raised by test()), `start_epoch` the first epoch index, and
# `best_acc_list` collects the best accuracy of each completed run.
best_acc, start_epoch = 0, 0
best_acc_list = list()

In [13]:
def train(epoch):
    """Run one training epoch over ``train_loader`` and report its metrics.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``. Prints per-batch progress followed by an
    end-of-epoch summary.

    Args:
        epoch: Epoch index, used only for logging.

    Returns:
        ``(mean_loss, accuracy)`` where ``mean_loss`` is the average batch loss
        and ``accuracy`` is the percentage of correctly classified samples.
        Both are ``0.0`` if the loader produced no batches.
    """
    print(f"\nTraining - Epoch: {epoch}")

    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        print(f"Training - Epoch: {epoch}, Batch: {batch_idx}")

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # .item() detaches the scalar so the graph is not kept alive across batches.
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    # Guard against an empty loader instead of dividing by zero.
    mean_loss = train_loss / num_batches if num_batches else 0.0
    accuracy = 100.0 * correct / total if total else 0.0
    print(
        f"Training - Epoch: {epoch}, Loss: {mean_loss:.4f}, "
        f"Accuracy: {accuracy:.2f}% ({correct}/{total})"
    )
    return mean_loss, accuracy

In [14]:
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and checkpoint on a new best accuracy.

    Uses the notebook-level ``model``, ``criterion``, ``test_loader`` and
    ``device``, and updates the global ``best_acc``. When the accuracy beats
    ``best_acc``, saves ``{"net", "acc", "epoch"}`` to
    ``./checkpoint/ckpt.pth``, creating the directory if needed.

    Args:
        epoch: Epoch index, used for logging and stored in the checkpoint.

    Returns:
        The test accuracy in percent, or ``0.0`` if no samples were evaluated
        (in that case nothing is checkpointed).
    """
    global best_acc
    print(f"\nTesting - Epoch: {epoch}")

    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            print(f"Testing - Epoch: {epoch}, Batch: {batch_idx}")

            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1
            print(f"Correct: {correct}, Total: {total}")

    if total == 0:
        # Empty loader: avoid a ZeroDivisionError and do not overwrite the checkpoint.
        print(f"Testing - Epoch: {epoch}, no samples evaluated")
        return 0.0

    acc = 100.0 * correct / total
    mean_loss = test_loss / num_batches
    print(
        f"Testing - Epoch: {epoch}, Loss: {mean_loss:.4f}, "
        f"Accuracy: {acc:.2f}% ({correct}/{total})"
    )

    # Save checkpoint only when accuracy improves on the best seen so far.
    if acc > best_acc:
        print("Saving...")
        state = {
            "net": model.state_dict(),
            "acc": acc,
            "epoch": epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs("checkpoint", exist_ok=True)
        torch.save(state, "./checkpoint/ckpt.pth")
        best_acc = acc
    return acc

In [15]:
# One epoch: train, evaluate (test() checkpoints on a new best), then advance
# the learning-rate schedule.
for epoch in range(1):
    train(epoch)
    test(epoch)
    scheduler.step()

# Record this run's best accuracy, then reset it so another run starts fresh.
best_acc_list.append(best_acc)
best_acc = 0

print(
    f"Average accuracy: {np.mean(best_acc_list)} \t standard deviation: {np.std(best_acc_list)}"
)


Training - Epoch: 0
Training - Epoch: 0, Batch: 0
Training - Epoch: 0, Batch: 1
Training - Epoch: 0, Batch: 2
Training - Epoch: 0, Batch: 3
Training - Epoch: 0, Batch: 4
Training - Epoch: 0, Batch: 5
Training - Epoch: 0, Batch: 6
Training - Epoch: 0, Batch: 7
Training - Epoch: 0, Batch: 8
Training - Epoch: 0, Batch: 9
Training - Epoch: 0, Batch: 10
Training - Epoch: 0, Batch: 11
Training - Epoch: 0, Batch: 12
Training - Epoch: 0, Batch: 13
Training - Epoch: 0, Batch: 14
Training - Epoch: 0, Batch: 15
Training - Epoch: 0, Batch: 16
Training - Epoch: 0, Batch: 17
Training - Epoch: 0, Batch: 18
Training - Epoch: 0, Batch: 19
Training - Epoch: 0, Batch: 20
Training - Epoch: 0, Batch: 21
Training - Epoch: 0, Batch: 22
Training - Epoch: 0, Batch: 23
Training - Epoch: 0, Batch: 24
Training - Epoch: 0, Batch: 25
Training - Epoch: 0, Batch: 26
Training - Epoch: 0, Batch: 27
Training - Epoch: 0, Batch: 28
Training - Epoch: 0, Batch: 29
Training - Epoch: 0, Batch: 30
Training - Epoch: 0, Batch: 