In [35]:
# %%capture
# !pip install jcopdl

# Import Common Packages

In [36]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback

# Use the first CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

# Custom Dataset & Dataloader

In [37]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [38]:
# Mini-batch size shared by the train and test loaders.
bs = 1024

# Root folder in torchvision ImageFolder layout: <root>/<split>/<class_name>/<image files>.
mnist_path = './data/'

# ToTensor turns each PIL image into a float (C, H, W) tensor scaled to [0, 1].
# ImageFolder's default loader opens images as RGB, hence 3 input channels for the CNN.
transform = transforms.Compose([transforms.ToTensor()])

train_set = datasets.ImageFolder(root=mnist_path + 'train/', transform=transform)
# Reshuffle every epoch so mini-batches are not ordered by class folder.
trainloader = DataLoader(train_set, batch_size=bs, shuffle=True, num_workers=4)

test_set = datasets.ImageFolder(root=mnist_path + 'test/', transform=transform)
# Evaluation aggregates over the whole split, so batch order does not affect the metrics.
# Keep it fixed (shuffle=False) so test passes are deterministic and skip the permutation.
testloader = DataLoader(test_set, batch_size=bs, shuffle=False)

# Architecture & Config

In [43]:
class MNIST_CNN(nn.Module):
    """Small two-stage CNN classifier emitting log-probabilities over 10 classes.

    Feature extractor: two [Conv3x3 -> Mish -> MaxPool2x2] stages (3 -> 16 -> 32 channels),
    then flattened. The first fully-connected layer expects 1568 = 32 * 7 * 7 features.
    Because each conv keeps the spatial size (kernel 3, padding 1) and each pool halves it,
    this corresponds to 28x28 inputs. The head is Linear -> ReLU -> Dropout -> Linear -> LogSoftmax,
    so the output pairs with ``nn.NLLLoss``.
    """

    def __init__(self):
        super().__init__()
        # Layer order (and therefore state_dict key names such as "conv.0.weight") matters
        # for loading saved weights, so keep the Sequential indices stable.
        self.conv = nn.Sequential(
            # stage 1: 3 -> 16 channels, spatial size halved by the pool
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.Mish(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # stage 2: 16 -> 32 channels, spatial size halved again
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.Mish(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # (N, 32, H', W') -> (N, 32 * H' * W')
            nn.Flatten(),
        )

        self.fc = nn.Sequential(
            nn.Linear(in_features=1568, out_features=256),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=256, out_features=10),
            # log-probabilities along the class dimension
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images ``x``."""
        features = self.conv(x)
        log_probs = self.fc(features)
        return log_probs

In [44]:
# Move the model to the same device the batches are sent to in loop_fn; otherwise the
# forward pass fails on a CUDA machine (CPU weights vs. GPU inputs). Do this *before*
# building the optimizer so it is constructed over the parameters on the final device.
model = MNIST_CNN().to(device)
# NLLLoss pairs with the model's LogSoftmax output.
criterion = nn.NLLLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.005)
callback = Callback(model, config=None, outdir="model")

In [45]:
from tqdm.auto import tqdm

def loop_fn(mode, dataset, dataloader, model, criterion, optimizer, device):
    """Run one full pass over ``dataloader`` in training or evaluation mode.

    Args:
        mode: ``"train"`` to update the weights, ``"test"`` to evaluate only.
        dataset: The dataset behind ``dataloader``; ``len(dataset)`` is used to average metrics.
        dataloader: Iterable of ``(feature, target)`` batches.
        model: Module to run; expected to already be on ``device``.
        criterion: Loss called as ``criterion(output, target)``. Assumed to return a per-batch
            mean (the default reduction for ``nn.NLLLoss``), which is re-weighted by batch size.
        optimizer: Stepped once per batch in ``"train"`` mode; unused in ``"test"`` mode.
        device: Device each batch is moved to.

    Returns:
        ``(cost, acc)``: sample-weighted mean loss and accuracy over ``dataset``.

    Raises:
        ValueError: If ``mode`` is neither ``"train"`` nor ``"test"``. Previously an unknown
            mode silently left the model in whatever train/eval state it had last.
    """
    if mode not in ("train", "test"):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
    is_train = mode == "train"

    # train(True) enables dropout; train(False) is equivalent to eval().
    model.train(is_train)
    cost = correct = 0
    # Track gradients only while training. In "test" mode this acts like torch.no_grad()
    # even if the caller does not wrap the call in one.
    with torch.set_grad_enabled(is_train):
        for feature, target in tqdm(dataloader, desc=mode.title()):
            feature, target = feature.to(device), target.to(device)
            output = model(feature)
            loss = criterion(output, target)

            if is_train:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # Undo the per-batch mean so the final division gives a per-sample average
            # even when the last batch is smaller.
            cost += loss.item() * feature.shape[0]
            correct += (output.argmax(1) == target).sum().item()
    cost = cost / len(dataset)
    acc = correct / len(dataset)
    return cost, acc

In [None]:
# Train until early stopping triggers. loop_fn returns (cost, accuracy), so unpack both values.
# Previously the whole tuple was bound to train_cost/test_cost and passed to the callback.
while True:
    train_cost, train_score = loop_fn("train", train_set, trainloader, model, criterion, optimizer, device)
    with torch.no_grad():
        test_cost, test_score = loop_fn("test", test_set, testloader, model, criterion, optimizer, device)

    # Logging: losses and accuracies for this epoch
    callback.log(train_cost, test_cost, train_score, test_score)

    # Checkpoint
    callback.save_checkpoint()

    # Runtime Plotting
    callback.cost_runtime_plotting()

    # Early Stopping (monitors the scalar test loss)
    if callback.early_stopping(model, monitor="test_cost"):
        callback.plot_cost()
        break

Train:   0%|          | 0/59 [00:04<?, ?it/s]

In [None]:
# Persist only the learned parameters (the state_dict), not the pickled module object.
torch.save(obj=model.state_dict(), f="model/weights.pth")

In [None]:
# Load the saved parameters onto the CPU first so this works on machines without a GPU,
# copy them into the existing model, then move the model to the active device.
# NOTE(review): on torch>=1.13 consider torch.load(..., weights_only=True) to avoid
# unpickling arbitrary objects if this file may come from an untrusted source.
weights = torch.load(f="model/weights.pth", map_location="cpu")
model.load_state_dict(weights)
model = model.to(device)