In [3]:
import os

# Move from this notebook's folder to the project root so relative paths such as
# 'data' resolve. Assumes the notebook lives two levels below the root -- TODO confirm.
# PROJECT_ROOT is computed only on the first execution, so re-running this cell
# does not keep climbing the directory tree (plain os.chdir('../../') would).
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
os.chdir(PROJECT_ROOT)

In [6]:
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.nn.functional import cross_entropy
from torchvision import datasets, transforms

In [7]:
# --- experiment configuration ---
BATCH_SIZE = 256            # examples per mini-batch for every DataLoader
RANDOM_SEED = 88            # seed for torch's global RNG
VALIDATION_FRACTION = 0.2   # fraction (not percent) of the training set held out for validation

# Seed torch's global RNG (weight init, dropout, sampler shuffling) for reproducibility.
# The trailing ';' only suppresses the notebook's display of the returned Generator.
torch.manual_seed(RANDOM_SEED);

<torch._C.Generator at 0x116b59bb0>

In [17]:
# Training pipeline: upscale to 40x40, then take a *random* 32x32 crop (augmentation).
train_transformer = transforms.Compose([
    transforms.Resize((40, 40)),
    transforms.RandomCrop((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation pipeline: same resize/normalisation but a deterministic centre crop.
test_transformer = transforms.Compose([
    transforms.Resize((40, 40)),
    transforms.CenterCrop((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_data = datasets.CIFAR10(root='data', train=True, transform=train_transformer, download=True)
# The same training images seen through the deterministic evaluation transform.
# Validation must not use random-crop augmentation: its accuracy drives the LR
# scheduler and should be comparable to test accuracy.
valid_data = datasets.CIFAR10(root='data', train=True, transform=test_transformer, download=True)
test_data = datasets.CIFAR10(root='data', train=False, transform=test_transformer, download=True)

# Hold out the last VALIDATION_FRACTION of the training set (fixed, non-overlapping split).
n_train_total = len(train_data)
n_valid = int(VALIDATION_FRACTION * n_train_total)
train_idxs = torch.arange(0, n_train_total - n_valid)
valid_idxs = torch.arange(n_train_total - n_valid, n_train_total)
assert len(train_idxs) + len(valid_idxs) == n_train_total

train_dataloader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    sampler=SubsetRandomSampler(train_idxs),
)

valid_dataloader = DataLoader(
    valid_data,
    batch_size=BATCH_SIZE,
    sampler=SubsetRandomSampler(valid_idxs),
)

test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified


In [19]:
class CNN(torch.nn.Module):
    '''
    Convolutional neural net classifier for 3-channel images.

    Three Conv2d -> BatchNorm2d -> LeakyReLU stages (strides 2, 1, 2) are
    followed by a 128-unit hidden layer with dropout and a linear head that
    returns raw logits (no softmax; `fit` uses `cross_entropy`, which applies
    log-softmax itself). The hard-coded Linear(4096, ...) requires the conv
    stack to produce 64 x 8 x 8 features, which is what a 32x32 input yields.

    Args:
        n_classes: number of output logits / classes.
    '''
    def __init__(self, n_classes: int):
        super().__init__()

        self.n_classes = n_classes
        self.layers = torch.nn.Sequential(
            # 3 x 32 x 32 -> 32 x 16 x 16
            torch.nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            torch.nn.BatchNorm2d(32),
            torch.nn.LeakyReLU(0.1, inplace=True),
            # 32 x 16 x 16 -> 64 x 16 x 16
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.LeakyReLU(0.1, inplace=True),
            # 64 x 16 x 16 -> 64 x 8 x 8
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.LeakyReLU(0.1, inplace=True),
            # 64 * 8 * 8 = 4096 features
            torch.nn.Flatten(),
            torch.nn.Linear(4096, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.LeakyReLU(0.1, inplace=True),
            torch.nn.Dropout(0.5),
            # classification head
            torch.nn.Linear(128, n_classes),
        )

    def forward(self, x):
        '''Return unnormalised class scores (logits) for the batch `x`.'''
        logits = self.layers(x)
        return logits

    def fit(
        self,
        train_dataloader: DataLoader,
        valid_dataloader: DataLoader,
        learning_rate: float,
        epochs: int,
        verbose=False,
    ) -> None:
        '''
        Train with Adam + cross-entropy, lowering the LR when validation accuracy plateaus.

        Each epoch switches the model back to train mode for the optimisation
        pass (validation leaves it in eval mode), then measures accuracy on
        `valid_dataloader`. A ReduceLROnPlateau scheduler (mode='max',
        factor=0.1) monitors that validation accuracy.

        Args:
            train_dataloader: yields (x, y) batches of inputs and integer class labels.
            valid_dataloader: yields (x, y) batches used for validation accuracy.
            learning_rate: initial Adam learning rate.
            epochs: number of passes over `train_dataloader`.
            verbose: if True, print per-batch loss, per-epoch accuracies and LR reductions.

        Returns:
            None. When at least one epoch ran, the model is left in eval mode.
        '''
        import math

        optim = torch.optim.Adam(self.parameters(), lr=learning_rate)
        # mode='max': the monitored metric is an accuracy (higher is better).
        # The scheduler's own `verbose` flag is deprecated in recent PyTorch,
        # so LR reductions are reported manually below instead.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim,
            factor=0.1,
            mode='max',
        )

        for e in range(epochs):
            train_accuracy = self._train_epoch(train_dataloader, optim, e, verbose)
            valid_accuracy = self.evaluate(valid_dataloader)

            # An empty validation loader yields NaN; don't feed that to the scheduler.
            if not math.isnan(valid_accuracy):
                lr_before = optim.param_groups[0]['lr']
                scheduler.step(valid_accuracy)
                lr_after = optim.param_groups[0]['lr']
                if verbose and lr_after < lr_before:
                    print(f"epoch:{e:.0f}, reducing learning rate to {lr_after:.4e}")

            if verbose:
                print(
                    f"epoch:{e:.0f} over, training_acc={train_accuracy:.3f}, valid_acc={valid_accuracy:.3f}"
                )

    def _train_epoch(self, dataloader, optim, epoch, verbose):
        '''Run one optimisation pass over `dataloader`; return training accuracy (NaN if empty).

        The accuracy is computed from the train-mode predictions made during the
        pass (dropout active, weights changing), so it is only a running estimate.
        '''
        # Must be re-enabled every epoch: `evaluate` switches the model to eval mode.
        self.train()
        correct, n_examples = 0, 0

        for batch_idx, (x, y) in enumerate(dataloader):
            logits = self(x)
            loss = cross_entropy(logits, y)
            optim.zero_grad()
            loss.backward()
            optim.step()

            # Count always (not only when verbose) so the accuracy is always defined.
            correct += (logits.detach().argmax(dim=1) == y).sum().item()
            n_examples += y.size(0)

            if verbose:
                print(f"epoch:{epoch:.0f}, batch:{batch_idx:.0f}. Loss={loss.item():.3f}")

        return correct / n_examples if n_examples else float('nan')

    @torch.no_grad()
    def evaluate(self, dataloader: DataLoader) -> float:
        '''Return classification accuracy on `dataloader` (NaN if it yields no examples).

        Switches the model to eval mode (BatchNorm uses running statistics,
        dropout is disabled) and runs without building autograd graphs. The
        model is left in eval mode.
        '''
        self.eval()
        correct, n_examples = 0, 0
        for x, y in dataloader:
            correct += (self(x).argmax(dim=1) == y).sum().item()
            n_examples += y.size(0)
        return correct / n_examples if n_examples else float('nan')

In [20]:
N_CLASSES = 10  # CIFAR-10 has ten object categories
cnn = CNN(n_classes=N_CLASSES)

In [22]:
# NOTE(review): lr=0.1 is unusually high for Adam (PyTorch's default is 1e-3);
# kept as-is to preserve the original experiment -- consider lowering it.
cnn.fit(
    train_dataloader,
    valid_dataloader,
    learning_rate=0.1,
    epochs=50,
    verbose=True,
)

epoch:0, batch:0. Loss=3.401
epoch:0, batch:1. Loss=2.717
epoch:0, batch:2. Loss=2.390
epoch:0, batch:3. Loss=2.175
epoch:0, batch:4. Loss=2.191
epoch:0, batch:5. Loss=2.121
epoch:0, batch:6. Loss=2.006
epoch:0, batch:7. Loss=1.851
epoch:0, batch:8. Loss=1.910
epoch:0, batch:9. Loss=1.947
epoch:0, batch:10. Loss=1.884
epoch:0, batch:11. Loss=1.751
epoch:0, batch:12. Loss=1.927
epoch:0, batch:13. Loss=1.807
epoch:0, batch:14. Loss=1.686
epoch:0, batch:15. Loss=1.775
epoch:0, batch:16. Loss=1.640
epoch:0, batch:17. Loss=1.678
epoch:0, batch:18. Loss=1.562
epoch:0, batch:19. Loss=1.692
epoch:0, batch:20. Loss=1.693
epoch:0, batch:21. Loss=1.610
epoch:0, batch:22. Loss=1.589
epoch:0, batch:23. Loss=1.589
epoch:0, batch:24. Loss=1.581
epoch:0, batch:25. Loss=1.548
epoch:0, batch:26. Loss=1.417
epoch:0, batch:27. Loss=1.705
epoch:0, batch:28. Loss=1.403
epoch:0, batch:29. Loss=1.548
epoch:0, batch:30. Loss=1.527
epoch:0, batch:31. Loss=1.486
epoch:0, batch:32. Loss=1.525
epoch:0, batch:33. L