In [1]:
import os
from types import SimpleNamespace

import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as T

from classifier import ArgMaxClassifier
from loss import DistillationLoss
from resnet import resnet18
from training import DistillationTrainer

# Experiment hyper-parameters.
# The rest of the notebook reads these as attributes (`config.lr`, `config.seed`,
# ...). Attribute access on a bare dict raises AttributeError, so the dict is
# kept as `config_dict` and exposed through a SimpleNamespace as `config`.
config_dict = {
    "lr": 1e-1,
    "momentum": 0,
    "weight_decay": 0,
    "lr_sched_factor": 1e-1,
    "lr_sched_patience": 2,
    "early_stopping": None,
    "checkpoint_file": 'checkpoints/confidence',
    "data_dir": os.path.expanduser('~/.pytorch-datasets'),
    "seed": 403,
    "data_mean": (0.4914, 0.4822, 0.4465),
    "data_std": (0.2471, 0.2435, 0.2616),
    "train_batch_size": 128,
    "test_batch_size": 32,
}
config = SimpleNamespace(**config_dict)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG early so that torch-driven randomness later in the
# notebook (augmentation, data order, and presumably model init) is repeatable.
seed = config.seed
torch.manual_seed(seed)


# Short display names for the ten CIFAR-10 classes.
cifar10_labels = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Per-channel normalisation statistics and loader batch sizes, read from config.
mean, std = config.data_mean, config.data_std
train_batch_size, test_batch_size = config.train_batch_size, config.test_batch_size

# Training pipeline: random 32x32 crop with 4px padding and a random horizontal
# flip for augmentation, followed by tensor conversion and normalisation.
_train_augmentations = [T.RandomCrop(32, padding=4), T.RandomHorizontalFlip()]
transform = T.Compose(_train_augmentations + [T.ToTensor(), T.Normalize(mean, std)])

[34m[1mwandb[0m: Currently logged in as: [33mtom-rahav[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
# Define datasets.
# The training split uses the augmenting `transform` (random crop + flip).
# The test split must be evaluated deterministically. It therefore gets only
# tensor conversion and the same normalisation; random augmentation at test
# time would make every evaluation pass see different images.
test_transform = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
ds_train = torchvision.datasets.CIFAR10(root=config.data_dir, download=True, train=True, transform=transform)
ds_test = torchvision.datasets.CIFAR10(root=config.data_dir, download=True, train=False, transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Define dataloaders.
# The training loader shuffles so that each epoch sees a different mini-batch
# order; with shuffle=False the same fixed order is replayed every epoch. The
# order is still reproducible because torch.manual_seed was called during setup.
# The test loader keeps a fixed order, which is all that evaluation needs.
dl_train = torch.utils.data.DataLoader(ds_train, train_batch_size, shuffle=True)
dl_test = torch.utils.data.DataLoader(ds_test, test_batch_size, shuffle=False)

In [4]:
# Define models
# A freshly constructed student and a pretrained teacher (same resnet18
# constructor), each wrapped in ArgMaxClassifier.
# NOTE: these run after torch.manual_seed, so if resnet18() draws random initial
# weights, the order of construction affects which weights each model gets.
# Keep this order to reproduce earlier runs.
# Untrained model (student)
student_model = resnet18()
student_classifier = ArgMaxClassifier(student_model)
# Pretrained model (teacher)
# pretrained=True loads a saved state dict (see the "Loaded a state dict" output).
# NOTE(review): the teacher is not put in eval mode or frozen here. Confirm that
# DistillationTrainer does this; otherwise BatchNorm running stats (if any) may
# drift while distilling.
teacher_model = resnet18(pretrained=True)
teacher_classifier = ArgMaxClassifier(teacher_model)


Exception in thread SystemMonitor:
Traceback (most recent call last):
  File "/home/tom.rahav/miniconda3/envs/deep-hw/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/tom.rahav/miniconda3/envs/deep-hw/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/tom.rahav/miniconda3/envs/deep-hw/lib/python3.8/site-packages/wandb/sdk/internal/system/system_monitor.py", line 118, in _start
    asset.start()
  File "/home/tom.rahav/miniconda3/envs/deep-hw/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/cpu.py", line 166, in start
    self.metrics_monitor.start()
  File "/home/tom.rahav/miniconda3/envs/deep-hw/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/interfaces.py", line 168, in start
    logger.info(f"Started {self._process.name}")
AttributeError: 'NoneType' object has no attribute 'name'


Loaded a state dict from: /home/tom.rahav/confidence_project/state_dicts/resnet18.pt


[]

In [5]:
# Define optimizer and scheduler:
# plain SGD over the student's parameters only (the teacher is not optimised),
# plus ReduceLROnPlateau. The scheduler multiplies the lr by `lr_sched_factor`
# after `lr_sched_patience` scheduler steps without improvement of a minimised
# metric.
# NOTE(review): `scheduler` is not handed to the trainer or to fit() in this
# notebook, so it appears never to be stepped. Confirm.
optimizer = optim.SGD(
    params=student_classifier.parameters(),
    lr=config.lr,
    momentum=config.momentum,
    weight_decay=config.weight_decay,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=config.lr_sched_factor,
    patience=config.lr_sched_patience,
)

In [6]:
# Define loss function
# DistillationLoss from the local `loss` module, built with its default
# arguments. See loss.DistillationLoss for the hyper-parameters this implies.
loss_fn = DistillationLoss()

In [7]:
# Define trainer
# Argument order: student (the model `optimizer` was built over), teacher,
# loss, optimizer, device.
# NOTE(review): `scheduler` from the optimizer cell is not passed here, so the
# trainer presumably cannot step it. Confirm whether DistillationTrainer accepts
# a scheduler.
trainer = DistillationTrainer(student_classifier, teacher_classifier, loss_fn, optimizer, device)

Exception in thread SystemMonitor:
Traceback (most recent call last):
  File "/home/tom.rahav/miniconda3/envs/deep-hw/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/tom.rahav/miniconda3/envs/deep-hw/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/tom.rahav/miniconda3/envs/deep-hw/lib/python3.8/site-packages/wandb/sdk/internal/system/system_monitor.py", line 118, in _start
    asset.start()
  File "/home/tom.rahav/miniconda3/envs/deep-hw/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/cpu.py", line 166, in start
    self.metrics_monitor.start()
  File "/home/tom.rahav/miniconda3/envs/deep-hw/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/interfaces.py", line 168, in start
    logger.info(f"Started {self._process.name}")
AttributeError: 'NoneType' object has no attribute 'name'
Exception in thread SystemMonitor:
Traceback (most recent call last):
  File "/home/tom

In [None]:
# Train the student, checkpointing under config.checkpoint_file (the trainer
# appears to append a suffix; see the "Saved checkpoint ..._agr" output).
num_epochs = 100

# `run` is presumably the active wandb run (e.g. `run = wandb.init(...)`), but
# no cell in this notebook creates it. Fail with an actionable message instead
# of a bare NameError when the notebook is run on a fresh kernel.
if "run" not in globals():
    raise NameError(
        "`run` is undefined: create the wandb run (e.g. `run = wandb.init(...)`) before calling trainer.fit"
    )

res = trainer.fit(
    dl_train,
    dl_test,
    run,
    num_epochs,
    checkpoints=config.checkpoint_file,
    early_stopping=config.early_stopping,
)

--- EPOCH 1/100 ---


train_batch:   0%|          | 0/391 [00:00<?, ?it/s]

test_batch:   0%|          | 0/313 [00:00<?, ?it/s]


*** Saved checkpoint checkpoints/confidence_agr
--- EPOCH 2/100 ---


train_batch:   0%|          | 0/391 [00:00<?, ?it/s]

wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)


test_batch:   0%|          | 0/313 [00:00<?, ?it/s]

wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)



*** Saved checkpoint checkpoints/confidence_agr
--- EPOCH 3/100 ---


train_batch:   0%|          | 0/391 [00:00<?, ?it/s]

wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
