In [2]:
from torchvision.datasets import MNIST, mnist, CIFAR10
from torchvision import transforms

import torch.nn.functional as F
import torch
from torch import nn
from torch.autograd import Variable
from torch.distributions import Categorical
from torch.utils.data import DataLoader

In [16]:
from itertools import chain
from pathlib import Path
import tempfile

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm_notebook

In [4]:
from ray import tune
from ray import train
from ray.train import Checkpoint, get_checkpoint
from ray.tune.schedulers import ASHAScheduler
import ray.cloudpickle as pickle

In [5]:
from torch.utils.tensorboard import SummaryWriter
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [6]:
class CustomTargetTransform:
    """Target transform that one-hot encodes an integer class label.

    Calling an instance with ``target`` returns a float tensor of length
    ``num_classes``, allocated on the global ``device``, holding 1 at index
    ``target`` and 0 everywhere else.
    """

    def __init__(self, num_classes=10):
        # Length of the produced one-hot vector.
        self.num_classes = num_classes

    def __call__(self, target):
        size = self.num_classes
        encoded = torch.zeros(size, device=device, dtype=torch.float)
        encoded[target] = 1
        return encoded

# PIL image -> uint8 (C, H, W) tensor -> float32 tensor on the global ``device``.
# PILToTensor does not rescale, so pixel values stay in [0, 255].
transform = transforms.Compose(
    [
        transforms.PILToTensor(),
        transforms.Lambda(lambda pic: pic.float().to(device)),
    ]
)

# data_loader = DataLoader(dataset, batch_size=800, shuffle=True)

In [7]:
# dataset = mnist.FashionMNIST("data", download=True, train=True, transform=transform, target_transform=CustomTargetTransform())
# dataset_target = mnist.FashionMNIST("data", download=True, train=False, transform=transforms.PILToTensor())
# Training split: float image tensors on ``device`` and one-hot float targets.
dataset = CIFAR10(
    "data",
    train=True,
    download=True,
    transform=transform,
    target_transform=CustomTargetTransform(),
)
# Held-out split; later cells read its raw ``.data`` / ``.targets`` directly.
dataset_target = CIFAR10("data", train=False, download=True, transform=transforms.PILToTensor())

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# CIFAR10 keeps its images as an (N, H, W, C) uint8 array, while the training
# transform (PILToTensor) yields (C, H, W) per image. permute(0, 3, 1, 2) gives
# (N, C, H, W); swapaxes(3, 1) would give (N, C, W, H), i.e. every test image
# transposed relative to what the model is trained on.
target_data = torch.as_tensor(dataset_target.data).permute(0, 3, 1, 2).float().to(device)
# FashionMNIST alternative ((N, H, W) grayscale): add a channel axis instead.
# target_data = torch.tensor(dataset_target.data).unsqueeze(1).float().to(device)
target_labels = torch.tensor(dataset_target.targets).float().to(device)
target_data.shape, target_labels.shape

(torch.Size([10000, 3, 32, 32]), torch.Size([10000]))

In [9]:
class ResBlock(nn.Module):
    """Two-convolution residual block: ``relu(x + block(x))``.

    ``block`` is ``conv1 -> [BatchNorm2d] -> ReLU -> conv2 -> [BatchNorm2d]``.
    ``conv1`` maps ``in_channels -> inter_channels`` and ``conv2`` maps back to
    ``in_channels`` so the result can be added to the input. Because the skip
    connection is a plain addition, the two convolutions together must preserve
    the spatial size (e.g. kernel 3, stride 1, padding 1); otherwise ``forward``
    fails on the shape mismatch.

    Args:
        in_channels: channels of the input and of the block output.
        batch_norm1: append ``BatchNorm2d`` after the first convolution.
        batch_norm2: append ``BatchNorm2d`` after the second convolution.
        inter_channels: channels between the two convolutions.
        kernel_size1, kernel_size2: convolution kernel sizes.
        stride1, stride2: convolution strides.
        padding1, padding2: convolution paddings (int, tuple, or the
            ``nn.Conv2d`` strings ``"same"`` / ``"valid"``).
    """

    def __init__(self, in_channels, batch_norm1, batch_norm2, inter_channels, kernel_size1, kernel_size2, stride1, stride2, padding1=1, padding2=1):
        super().__init__()
        block = [nn.Conv2d(in_channels, inter_channels, kernel_size1, stride1, padding1)]
        if batch_norm1:
            block.append(nn.BatchNorm2d(inter_channels))
        block.append(nn.ReLU())
        block.append(nn.Conv2d(inter_channels, in_channels, kernel_size2, stride2, padding2))
        if batch_norm2:
            block.append(nn.BatchNorm2d(in_channels))
        self.block = nn.Sequential(*block)

    @staticmethod
    def _conv_out_size(in_size, conv):
        """Output length along the first spatial axis of ``conv`` (PyTorch's Conv2d formula)."""
        padding = conv.padding
        if padding == "same":
            # "same" keeps the spatial size unchanged (Conv2d only allows it with stride 1).
            return in_size
        pad = 0 if padding == "valid" else padding[0]
        kernel, stride, dilation = conv.kernel_size[0], conv.stride[0], conv.dilation[0]
        return (in_size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

    def get_out_size(self, in_size):
        """Spatial size after both convolutions for a square input of side ``in_size``.

        Only the first spatial dimension of each kernel/stride/padding is used,
        i.e. square inputs and kernels are assumed.
        """
        for layer in self.block:
            if isinstance(layer, nn.Conv2d):
                in_size = self._conv_out_size(in_size, layer)
        return in_size

    def forward(self, x):
        return F.relu(x + self.block(x))


In [10]:
# Bare class name of the training dataset (e.g. "CIFAR10"); create_model uses it
# as the prefix of each run name. ``__name__`` is already a plain str without a
# module path, so no str()/split is needed.
DATASET_NAME = type(dataset).__name__
# Counter appended to run names; create_model increments it for every model built.
TRAIN_ID = 0
CHECKPOINT_DIR = Path("check")
DATASET_NAME

'CIFAR10'

In [15]:
def _first_dim(value):
    """Return the first element of a per-dimension tuple/list, or ``value`` itself."""
    return value[0] if isinstance(value, (tuple, list)) else value


def _spatial_out_size(size, kernel_size, stride, padding=0, dilation=1, ceil_mode=False):
    """Output length along one spatial axis of a conv/pool layer (PyTorch's formula).

    ``padding`` may be an int or one of the ``nn.Conv2d`` strings ``"same"``
    (size unchanged) or ``"valid"`` (no padding). ``ceil_mode`` follows
    ``nn.MaxPool2d``.
    """
    if padding == "same":
        return size
    if padding == "valid":
        padding = 0
    span = size + 2 * padding - dilation * (kernel_size - 1) - 1
    if not ceil_mode:
        return span // stride + 1
    out = -(-span // stride) + 1
    # Like PyTorch, drop a last window that would start inside the right padding.
    if (out - 1) * stride >= size + padding:
        out -= 1
    return out


def create_model(
        img_size,
        config: dict[str, float],
        *layers,
        linears: list[int],
        epoch_count=10,
        softmax=True,
        in_channels=1,
        num_classes=10,
    ) -> tuple[nn.Sequential, nn.CrossEntropyLoss, torch.optim.AdamW, str, dict, int]:
    """Build a classifier ``nn.Sequential`` together with its loss, optimizer and run name.

    ``layers`` (already-constructed modules) come first. Their output spatial size
    is tracked starting from ``img_size`` so that the ``Flatten -> Linear``
    transition can be sized. Square images are assumed: only the first spatial
    dim of each kernel/stride/padding is used. ``ResBlock``, ``nn.Conv2d`` and
    ``nn.MaxPool2d`` update the size, and ``nn.Conv2d`` also updates the channel
    count; any other module is treated as size- and channel-preserving.

    Args:
        img_size: side length of the square input image.
        config: hyper-parameters; ``config["lr"]`` is the AdamW learning rate.
        *layers: feature-extraction modules, applied in order.
        linears: hidden widths of the fully connected head (each followed by ReLU).
        epoch_count: returned unchanged, for use by the training loop.
        softmax: append ``nn.Softmax(dim=1)`` after the output layer.
            NOTE: ``nn.CrossEntropyLoss`` already applies log-softmax to its
            input, so with ``softmax=True`` the loss receives probabilities
            instead of logits, which squashes gradients. Prefer
            ``softmax=False`` when training with the returned loss.
        in_channels: channel count of the input. Each ``nn.Conv2d`` in
            ``layers`` overrides it with its ``out_channels``.
        num_classes: width of the output layer.

    Returns:
        ``(model, loss_fn, optimizer, name, config, epoch_count)``, where
        ``name == f"{DATASET_NAME}_{TRAIN_ID}"``. The global ``TRAIN_ID`` is
        incremented and ``name`` is printed.
    """
    global TRAIN_ID
    # Seed before the head's Linear layers are created so their init is reproducible.
    torch.manual_seed(0)
    blocks = list(layers)
    size = img_size
    channels = in_channels
    for layer in layers:
        if isinstance(layer, ResBlock):
            size = layer.get_out_size(size)
        elif isinstance(layer, nn.Conv2d):
            size = _spatial_out_size(
                size,
                _first_dim(layer.kernel_size),
                _first_dim(layer.stride),
                _first_dim(layer.padding),
                _first_dim(layer.dilation),
            )
            channels = layer.out_channels
        elif isinstance(layer, nn.MaxPool2d):
            size = _spatial_out_size(
                size,
                _first_dim(layer.kernel_size),
                _first_dim(layer.stride),
                _first_dim(layer.padding),
                _first_dim(layer.dilation),
                ceil_mode=layer.ceil_mode,
            )
    features = size * size * channels
    blocks.append(nn.Flatten(1))
    for width in linears:
        blocks.append(nn.Linear(features, width))
        blocks.append(nn.ReLU())
        features = width
    blocks.append(nn.Linear(features, num_classes))
    if softmax:
        # Input is (batch, num_classes) here; normalise over the class axis.
        blocks.append(nn.Softmax(dim=1))
    model = nn.Sequential(*blocks).to(device)
    er_f = nn.CrossEntropyLoss()
    optim = torch.optim.AdamW(model.parameters(), lr=config["lr"])
    name = f"{DATASET_NAME}_{TRAIN_ID}"
    TRAIN_ID += 1
    print(name)
    return model, er_f, optim, name, config, epoch_count


In [None]:
def _evaluate_accuracy(model, data, labels, batch_size):
    """Fraction of rows of ``data`` whose argmax prediction equals ``labels``.

    Runs in eval mode (so BatchNorm uses, and does not update, its running
    stats) and without autograd, in chunks of ``batch_size`` rows. The model's
    previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for data_chunk, label_chunk in zip(data.split(batch_size), labels.split(batch_size)):
                correct += (model(data_chunk).argmax(1) == label_chunk).sum().item()
    finally:
        model.train(was_training)
    return correct / max(len(labels), 1)


def train_model(model, er_f, optim, name, config, epoch_count):
    """Train ``model`` on the global ``dataset`` as a Ray Tune trainable.

    Resumes from ``ray.train.get_checkpoint()`` when one exists. Logs per-batch
    loss/accuracy and per-epoch test accuracy to TensorBoard. After every epoch,
    calls ``train.report`` with the metrics and a fresh checkpoint.

    Args:
        model: network to train. Its outputs are compared by argmax against the
            one-hot targets yielded by ``dataset``.
        er_f: loss function, called as ``er_f(outputs, target)``.
        optim: optimizer over ``model.parameters()``.
        name: run name, used as the TensorBoard ``comment`` suffix.
        config: hyper-parameters; ``config["batch_count"]`` is the batch size
            (also used as the evaluation chunk size).
        epoch_count: total number of epochs, counting those completed before a resume.

    Reported metrics (per epoch):
        ``loss``: mean training loss over the epoch's batches.
        ``accuracy``: accuracy on the held-out ``target_data`` / ``target_labels``.
    """
    start_epoch = 0
    checkpoint = get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / "data.pkl"
            # Unpickling can execute arbitrary code: only resume from checkpoints this code wrote.
            with open(data_path, "rb") as fp:
                checkpoint_state = pickle.load(fp)
        # The stored epoch had already finished when it was checkpointed; continue with the next one.
        start_epoch = checkpoint_state["epoch"] + 1
        model.load_state_dict(checkpoint_state["net_state_dict"])
        optim.load_state_dict(checkpoint_state["optimizer_state_dict"])

    batch_size = config["batch_count"]
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Always log to a regular TensorBoard run dir. The directory from
    # ``checkpoint.as_directory()`` may be a temporary copy that is removed when
    # its ``with`` block exits, so it must not be used as a log dir.
    writer = SummaryWriter(comment=name)
    # Monotonic x-axis for per-batch scalars; continues across resumes.
    global_step = start_epoch * len(data_loader)
    try:
        for epoch in tqdm_notebook(range(start_epoch, epoch_count)):
            model.train()
            running_loss = 0.0
            epoch_steps = 0
            for image, target in tqdm_notebook(data_loader, leave=False):
                optim.zero_grad()
                outputs = model(image)
                loss = er_f(outputs, target)
                loss.backward()
                optim.step()

                loss_value = loss.item()
                running_loss += loss_value
                epoch_steps += 1
                # Targets are one-hot vectors, so compare argmaxes.
                train_accuracy = (outputs.argmax(1) == target.argmax(1)).float().mean().item()
                writer.add_scalar("loss", loss_value, global_step)
                writer.add_scalar("train_accuracy", train_accuracy, global_step)
                global_step += 1

            # Evaluate once per epoch, not after every batch.
            test_accuracy = _evaluate_accuracy(model, target_data, target_labels, batch_size)
            writer.add_scalar("test_accuracy", test_accuracy, epoch)

            checkpoint_data = {
                "epoch": epoch,
                "net_state_dict": model.state_dict(),
                "optimizer_state_dict": optim.state_dict(),
            }
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                data_path = Path(checkpoint_dir) / "data.pkl"
                with open(data_path, "wb") as fp:
                    pickle.dump(checkpoint_data, fp)

                # Report while the temp dir still exists so Ray can persist the checkpoint.
                train.report(
                    {"loss": running_loss / max(epoch_steps, 1), "accuracy": test_accuracy},
                    checkpoint=Checkpoint.from_directory(checkpoint_dir),
                )
    finally:
        writer.close()