In [4]:
import os
os.environ["WANDB_MODE"] = "disabled"

In [54]:
#hparams.py
%%bash
cat > hparams.py << 'EOF'
config = dict(
    seed = 33,
    batch_size=64,
    learning_rate=1e-5,
    weight_decay=0.01,
    epochs=2,
    zero_init_residual=False,
)
EOF

In [55]:
# prepare_data.py
%%bash
cat > prepare_data.py << 'EOF'
from torchvision.datasets import CIFAR10

def main(root="CIFAR10"):
    # Скачиваем и сохраняем train и test
    CIFAR10(root=f"{root}/train", train=True, download=True)
    CIFAR10(root=f"{root}/test",  train=False, download=True)

if __name__ == "__main__":
    main()
EOF


In [69]:
# train.py
%%bash
cat > train.py << 'EOF'
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
import wandb

from tqdm import tqdm, trange
from hparams import config

def set_seed(seed=config["seed"]):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(config["seed"])
wandb.init(config=config, project="effdl_example", name="baseline")

def compute_accuracy(preds, targets):
    return (targets == preds).float().mean().item()

def main():
    base = os.environ.get("CIFAR10_W", "CIFAR10")
    train_dir = os.path.join(base, "train")
    test_dir  = os.path.join(base, "test")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914,0.4822,0.4465),(0.247,0.243,0.261)),
    ])

    train_dataset = CIFAR10(root=train_dir, train=True,  transform=transform, download=True)
    test_dataset  = CIFAR10(root=test_dir,  train=False, transform=transform, download=True)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=config["batch_size"],
                                               shuffle=True)
    test_loader  = torch.utils.data.DataLoader(test_dataset,
                                               batch_size=config["batch_size"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model  = resnet18(pretrained=False,
                      num_classes=10,
                      zero_init_residual=config["zero_init_residual"])
    model.to(device)
    wandb.watch(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=config["learning_rate"],
                                  weight_decay=config["weight_decay"])

    for epoch in trange(config["epochs"], desc="Epoch"):
        for i, (images, labels) in enumerate(tqdm(train_loader, desc=f"Train {epoch}")):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss    = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if i % 100 == 0:
                all_preds, all_lbls = [], []
                for imgs, lbls in test_loader:
                    imgs, lbls = imgs.to(device), lbls.to(device)
                    with torch.inference_mode():
                        out = model(imgs)
                        preds = torch.argmax(out, 1)
                        all_preds.append(preds.cpu())
                        all_lbls.append(lbls.cpu())
                acc = compute_accuracy(torch.cat(all_preds), torch.cat(all_lbls))
                wandb.log({'test_acc': acc, 'train_loss': loss.item()},
                          step=epoch * len(train_loader) + i)

    torch.save(model.state_dict(), "model.pt")
    with open("run_id.txt", "w") as f:
        f.write(wandb.run.id)

if __name__ == "__main__":
    main()
EOF

In [68]:
# compute_metrics.py
%%bash
cat > compute_metrics.py << 'EOF'
import os
import json
import random
import numpy as np
import torch
from argparse import ArgumentParser
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from hparams import config

def set_seed(seed=config["seed"]):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def main(args=None):
    set_seed(config["seed"])
    base = os.environ.get("CIFAR10_W", "CIFAR10")
    test_dir = os.path.join(base, "test")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914,0.4822,0.4465),(0.247,0.243,0.261)),
    ])

    ds = CIFAR10(root=test_dir, train=False, transform=transform, download=False)
    loader = torch.utils.data.DataLoader(ds, batch_size=config["batch_size"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = resnet18(pretrained=False,
                     num_classes=10,
                     zero_init_residual=config["zero_init_residual"])
    model.load_state_dict(torch.load("model.pt", map_location=device))
    model.to(device)

    correct = 0
    total   = len(ds)
    with torch.inference_mode():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            out   = model(imgs)
            preds = torch.argmax(out, 1)
            correct += (preds == lbls).sum().item()

    acc = correct / total
    with open("final_metrics.json", "w") as f:
        json.dump({"accuracy": acc}, f)

if __name__ == "__main__":
    parser = ArgumentParser()
    args = parser.parse_args()
    main(args)
EOF

In [58]:
# test_basic.py
%%bash
cat > test_basic.py << 'EOF'

import torch
import pytest
import os

from train import compute_accuracy

def test_arange_elems():
    arr = torch.arange(0, 10, dtype=torch.float)
    assert arr[-1].item() == 9.0

def test_div_zero():
    a = torch.zeros(1, dtype=torch.float)
    b = torch.ones(1, dtype=torch.float)
    result = b / a
    assert torch.isinf(result).any()


def test_div_zero_python():
    with pytest.raises(ZeroDivisionError):
        1/0

def test_accuracy():
    preds = torch.randint(0,2,size=(100,))
    targets = preds.clone()

    assert compute_accuracy(preds, targets) == 1.0

    preds = torch.tensor([1,2,3,0,0,0])
    targets = torch.tensor([1,2,3,4,5,6])

    assert compute_accuracy(preds, targets) == 0.5

@pytest.mark.parametrize("preds,targets,result",[
    (torch.tensor([1,2,3]), torch.tensor([1,2,3]), 1.0),
    (torch.tensor([1,2,3]), torch.tensor([0,0,0]), 0.0),
    (torch.tensor([1,2,3]), torch.tensor([1,2,0]), 2/3),
])
def test_accuracy_parametrized(preds, targets, result):
    actual = compute_accuracy(preds, targets)
    assert abs(actual - result) < 1e-5


EOF

In [10]:
!pytest test_basic.py -q

[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m                                                                  [100%][0m
[32m[32m[1m7 passed[0m[32m in 15.58s[0m[0m


In [73]:
# test_pipeline.py
%%bash
cat > test_pipeline.py << 'EOF'
import pytest
import torch
import os, json
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

from train import compute_accuracy, main as train_main
from compute_metrics import main as metrics_main
from prepare_data import main as prepare_data
from hparams import config

@pytest.fixture(scope="session")
def data_root(tmp_path_factory):
    return str(tmp_path_factory.mktemp("CIFAR10"))

@pytest.fixture(scope="session")
def train_dataset(data_root):
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914,0.4822,0.4465),(0.247,0.243,0.261))
    ])
    return CIFAR10(root=data_root, train=True, download=True, transform=tf)

@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_train_on_one_batch(device, train_dataset):
    if device=="cuda" and not torch.cuda.is_available():
        pytest.skip("no GPU")
    loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    imgs, lbls = next(iter(loader))
    imgs, lbls = imgs.to(device), lbls.to(device)
    model = torch.hub.load('pytorch/vision:v0.14.1',
                           'resnet18',
                           pretrained=False,
                           num_classes=10).to(device)
    opt  = torch.optim.AdamW(model.parameters(), lr=1e-5)
    crit = torch.nn.CrossEntropyLoss()
    model.train()
    out  = model(imgs)
    loss = crit(out, lbls)
    loss.backward()
    opt.step()
    preds = torch.argmax(out, dim=1)
    acc   = compute_accuracy(preds.cpu(), lbls.cpu())
    assert loss.item() >= 0.0
    assert 0.0 <= acc <= 1.0

@pytest.mark.parametrize(
    "batch_size, learning_rate, epochs",
    [(32, 0.001, 1), (64, 0.01, 1), (128, 0.0001, 1)]
)
def test_training(batch_size, learning_rate, epochs, data_root, tmp_path):
    config["epochs"]     = epochs
    config["batch_size"] = batch_size
    config["learning_rate"] = learning_rate
    os.environ["CIFAR10_W"] = data_root

    prepare_data(data_root)

    train_main()

    assert os.path.isfile("model.pt")
    assert os.path.isfile("run_id.txt")

    metrics_main()
    assert os.path.isfile("final_metrics.json")
    metrics = json.load(open("final_metrics.json"))
    assert "accuracy" in metrics, "acuracy don't found"
    assert 0.0 <= metrics["accuracy"] <= 1.0, f"Invalid accuracy: {metrics['accuracy']}"

EOF


In [23]:
!pip install pytest-cov



In [74]:
!pytest --maxfail=1 --disable-warnings -q --cov=. --cov-report=term-missing

[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[33ms[0m[32m.[0m[32m.[0m[32m.[0m[33m                                                             [100%][0m
_______________ coverage: platform linux, python 3.11.12-final-0 _______________

Name                 Stmts   Miss  Cover   Missing
--------------------------------------------------
compute_metrics.py      42      3    93%   53-55
hparams.py               1      0   100%
prepare_data.py          6      1    83%   9
test_basic.py           26      0   100%
test_pipeline.py        51      0   100%
train.py                60      1    98%   85
--------------------------------------------------
TOTAL                  186      5    97%
