# SingleCNN Training Script

## Import libraries

In [45]:
import sys
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Import the the following path to use user-defined modules
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from utils.constants import TRAIN_ORIGINAL_FILE_PATH
from utils.utils import seed_worker
from models.single_cnn import SingleCNN

In [46]:
device = "cuda:6" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [31]:
g = torch.Generator()
g.manual_seed(0)

<torch._C.Generator at 0x23c09c24ab0>

## Retrive train data

In [32]:
train_data = None
with open(TRAIN_ORIGINAL_FILE_PATH, "rb") as file:
    train_data = pickle.load(file)

In [33]:
train_data.targets = torch.tensor(train_data.targets)
train_data.cluster = train_data.targets
train_data.targets = torch.zeros_like(train_data.targets)

In [34]:
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    worker_init_fn=seed_worker,
    generator=g,
)

## Retrive test data

In [35]:
test_data = None
with open(TRAIN_ORIGINAL_FILE_PATH, "rb") as file:
    test_data = pickle.load(file)

In [36]:
test_data.targets = torch.tensor(test_data.targets)
test_data.cluster = test_data.targets
test_data.targets = torch.zeros_like(test_data.targets)

In [37]:
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=100,
    shuffle=True,
    num_workers=2,
    worker_init_fn=seed_worker,
    generator=g,
)

## Create the model

In [38]:
model = SingleCNN(input_channels=3, num_classes=10)
optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

## Criterion

In [39]:
criterion = nn.CrossEntropyLoss()

## Train the model

In [40]:
best_acc = 0  # best test accuracy
best_acc_list = []
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

In [41]:
def train(epoch):
    global model, optimizer, scheduler, criterion
    print(f"\nEpoch: {epoch}")

    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        print(f"Epoch: {epoch}, Batch: {batch_idx}")

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # for optim in optimizers:
        #     optim.zero_grad()

        # if args.mixture:
        #     outputs, _, loss, _ = net(inputs)
        #     loss = criterion(outputs, targets) + 0.01 * loss
        # else:
        #     if args.model == "resnet18":
        #         outputs, _ = net(inputs)
        #     else:
        #         outputs = net(inputs)

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        optimizer.step()

        # for optim in optimizers:
        #     optim.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

In [42]:
def test(epoch):
    global best_acc, model, optimizer, scheduler, criterion

    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets, clusters = (inputs.to(device), targets.to(device))
            # if args.mixture:
            #     outputs, select0, _, _ = net(inputs)
            # else:
            #     if args.model == "resnet18":
            #         outputs, _ = net(inputs)
            #     else:
            #         outputs = net(inputs)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100.0 * correct / total
    if acc > best_acc:
        print("Saving..")
        state = {
            "net": model.state_dict(),
            "acc": acc,
            "epoch": epoch,
        }
        if not os.path.isdir("checkpoint"):
            os.mkdir("checkpoint")
        torch.save(state, "./checkpoint/ckpt.pth")
        best_acc = acc

In [43]:
for epoch in range(0, 1):
    train(epoch)
    test(epoch)
    scheduler.step()

best_acc_list.append(best_acc)
best_acc = 0

print(
    f"Average accuracy: {np.mean(best_acc_list)} \t standard deviation: {np.std(best_acc_list)}"
)


Epoch: 0
Epoch: 0, Batch: 0
Epoch: 0, Batch: 1
Epoch: 0, Batch: 2
Epoch: 0, Batch: 3
Epoch: 0, Batch: 4
Epoch: 0, Batch: 5
Epoch: 0, Batch: 6
Epoch: 0, Batch: 7
Epoch: 0, Batch: 8
Epoch: 0, Batch: 9
Epoch: 0, Batch: 10
Epoch: 0, Batch: 11
Epoch: 0, Batch: 12
Epoch: 0, Batch: 13
Epoch: 0, Batch: 14
Epoch: 0, Batch: 15
Epoch: 0, Batch: 16
Epoch: 0, Batch: 17
Epoch: 0, Batch: 18
Epoch: 0, Batch: 19
Epoch: 0, Batch: 20
Epoch: 0, Batch: 21
Epoch: 0, Batch: 22
Epoch: 0, Batch: 23
Epoch: 0, Batch: 24
Epoch: 0, Batch: 25
Epoch: 0, Batch: 26
Epoch: 0, Batch: 27
Epoch: 0, Batch: 28
Epoch: 0, Batch: 29
Epoch: 0, Batch: 30
Epoch: 0, Batch: 31
Epoch: 0, Batch: 32
Epoch: 0, Batch: 33
Epoch: 0, Batch: 34
Epoch: 0, Batch: 35
Epoch: 0, Batch: 36
Epoch: 0, Batch: 37
Epoch: 0, Batch: 38
Epoch: 0, Batch: 39
Epoch: 0, Batch: 40
Epoch: 0, Batch: 41
Epoch: 0, Batch: 42
Epoch: 0, Batch: 43
Epoch: 0, Batch: 44
Epoch: 0, Batch: 45
Epoch: 0, Batch: 46
Epoch: 0, Batch: 47
Epoch: 0, Batch: 48
Epoch: 0, Batch: 49


KeyboardInterrupt: 