In [1]:
from fastai2.vision.all import *
from nbdev.showdoc import *
import glob
import albumentations
from torchvision import models
from albumentations.pytorch.transforms import ToTensorV2
set_seed(2)

In [2]:
bs = 4

In [3]:
path = untar_data(URLs.PETS); path

Path('/home/ubuntu/.fastai/data/oxford-iiit-pet')

In [4]:
(path/'images').ls()

(#7381) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]

## `Dataset`

In [5]:
class PetsDataset:
    def __init__(self, paths, transforms=None):
        self.image_paths = paths
        self.transforms = transforms
        
    def __len__(self): 
        return len(self.image_paths)
    
    def setup(self, pat=r'(.+)_\d+.jpg$', label2int=None):
        "adds a label dictionary to `self`"
        self.pat = re.compile(pat)
        if label2int is not None:
            self.label2int = label2int
            self.int2label = {v:i for i,v in self.label2int.items()}
        else:
            labels = [os.path.basename(self.pat.search(str(p)).group(1))
                  for p in self.image_paths]
            self.labels = set(labels)
            self.label2int = {label:i for i,label in enumerate(self.labels)}
            self.int2label = {v:i for i,v in self.label2int.items()}

    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = Image.open(img_path)
        img = np.array(img)
        
        target = os.path.basename(self.pat.search(str(img_path)).group(1))
        target = self.label2int[target]
        
        if self.transforms:
            img = self.transforms(image=img)['image']      
            
        return img, torch.tensor(target, dtype=torch.long)

In [6]:
image_paths = get_image_files(path/'images')
image_paths

(#7378) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]

In [7]:
from tqdm.notebook import tqdm
run_remove = False
def remove(o):
    img = Image.open(o)
    img = np.array(img)
    if img.shape[2] != 3:
        os.remove(o)
if run_remove:
    for o in tqdm(image_paths): remove(o)

In [8]:
image_paths = get_image_files(path/'images')
image_paths

(#7378) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]

In [9]:
sz = 224
tfms = albumentations.Compose([
    albumentations.Resize(sz, sz) if sz else albumentations.NoOp(),
    albumentations.OneOf(
        [albumentations.Cutout(random.randint(1,8), 16, 16),
         albumentations.CoarseDropout(random.randint(1,8), 16, 16)]
    ),
    albumentations.Normalize(always_apply=True),
    ToTensorV2()
])

In [10]:
dataset = PetsDataset(image_paths, tfms)

In [11]:
dataset.setup()

In [12]:
dataset[0]

(tensor([[[ 0.8618,  0.1597,  0.4166,  ..., -0.6452, -0.3198, -0.2171],
          [ 1.1872,  0.3481,  0.4166,  ..., -0.3027,  0.0912,  0.3138],
          [ 0.8104,  0.6049,  0.0227,  ..., -0.3712, -0.1657, -0.1828],
          ...,
          [ 1.2385,  0.4851,  0.0227,  ...,  0.8789,  1.2214,  0.8961],
          [ 0.7077,  0.9474, -0.6965,  ...,  0.1254,  1.5297,  1.6667],
          [ 0.1083, -0.0801,  0.3652,  ...,  0.2111,  0.5193,  0.6734]],
 
         [[ 0.9230,  0.4328,  0.4503,  ..., -0.2850, -0.0224, -0.0399],
          [ 1.3256,  0.7304,  0.4678,  ..., -0.0399,  0.1527,  0.3277],
          [ 0.8354,  0.8354,  0.3102,  ..., -0.2500, -0.1975, -0.3200],
          ...,
          [ 1.3606,  1.3431,  0.6078,  ...,  0.9755,  1.3957,  1.1331],
          [ 0.7654,  1.0455, -0.0574,  ...,  0.7654,  1.6232,  1.7458],
          [ 0.4153,  0.5903,  0.9230,  ...,  0.7654,  0.8529,  1.0980]],
 
         [[ 0.3393, -0.3578, -0.4275,  ..., -0.7936, -0.4624, -0.3578],
          [ 0.6531, -0.2358,

In [13]:
dataset[0][0].shape

torch.Size([3, 224, 224])

## `DataLoaders`

In [14]:
nval = int(len(image_paths)*0.2)
nval

1475

In [15]:
trn_img_paths = image_paths[:-nval]
val_img_paths = image_paths[-nval:]
assert len(trn_img_paths) + len(val_img_paths) == len(image_paths)
len(trn_img_paths), len(val_img_paths)

(5903, 1475)

In [16]:
trn_dataset = PetsDataset(trn_img_paths, transforms=tfms)
val_dataset = PetsDataset(val_img_paths, transforms=tfms)

In [17]:
trn_dataset.setup(label2int=dataset.label2int)
val_dataset.setup(label2int=dataset.label2int)

In [18]:
# bs = 64

In [19]:
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=bs, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bs, num_workers=4, shuffle=False)

In [20]:
next(iter(trn_loader))[0].shape, next(iter(val_loader))[0].shape

(torch.Size([4, 3, 224, 224]), torch.Size([4, 3, 224, 224]))

## `Model`

In [21]:
resnet34_bn = models.resnet34(num_classes=len(trn_dataset.label2int))
resnet34_bn

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [22]:
class GroupNorm_32(torch.nn.GroupNorm):
    def __init__(self, num_channels, num_groups=32, **kwargs):
        super().__init__(num_groups, num_channels, **kwargs)

In [23]:
resnet34_gn = models.resnet34(norm_layer=GroupNorm_32, num_classes=len(trn_dataset.label2int))
resnet34_gn

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm_32(32, 64, eps=1e-05, affi

In [24]:
resnet34_gn(next(iter(trn_loader))[0]).shape

torch.Size([4, 37])

## Training using `PytorchLightning`

In [25]:
from pytorch_lightning import LightningModule, Trainer

In [26]:
class Model(LightningModule):
    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, x):
        return self.base(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def step(self, batch):
        x, y  = batch
        y_hat = self(x)
        loss  = nn.CrossEntropyLoss()(y_hat, y)
        return loss, y, y_hat

    def training_step(self, batch, batch_nb):
        loss, _, _ = self.step(batch)
        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        loss, y, y_hat = self.step(batch)
        return {'loss': loss, 'y': y.detach(), 'y_hat': y_hat.detach()}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        acc = self.get_accuracy(outputs)
        print(f"Epoch:{self.current_epoch} | Loss:{avg_loss} | Accuracy:{acc}")
        return {'loss': avg_loss}
    
    def get_accuracy(self, outputs):
        from sklearn.metrics import accuracy_score
        y = torch.cat([x['y'] for x in outputs])
        y_hat = torch.cat([x['y_hat'] for x in outputs])
        preds = y_hat.argmax(1)
        return accuracy_score(y.cpu().numpy(), preds.cpu().numpy())

In [27]:
model_bn = Model(resnet34_bn)
model_gn = Model(resnet34_gn)

In [28]:
debug = False
gpus = torch.cuda.device_count()
trainer = Trainer(gpus=gpus, max_epochs=50, 
                  num_sanity_val_steps=1 if debug else 0)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [29]:
trainer = Trainer(gpus=gpus, max_epochs=25, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:0 | Loss:3.8165385723114014 | Accuracy:0.05694915254237288


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:1 | Loss:3.4834210872650146 | Accuracy:0.07186440677966102


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:2 | Loss:3.6691737174987793 | Accuracy:0.07254237288135593


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:3 | Loss:3.4139177799224854 | Accuracy:0.08


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:4 | Loss:3.2902932167053223 | Accuracy:0.09627118644067796


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:5 | Loss:3.308424711227417 | Accuracy:0.10847457627118644


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:6 | Loss:3.1262896060943604 | Accuracy:0.13559322033898305


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:7 | Loss:3.1067755222320557 | Accuracy:0.14576271186440679


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:8 | Loss:3.039379596710205 | Accuracy:0.14983050847457627


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:9 | Loss:2.8480300903320312 | Accuracy:0.18508474576271186


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:10 | Loss:2.803644895553589 | Accuracy:0.20813559322033898


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:11 | Loss:2.7595109939575195 | Accuracy:0.23457627118644067


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:12 | Loss:2.5479202270507812 | Accuracy:0.28271186440677964


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:13 | Loss:2.571686029434204 | Accuracy:0.2983050847457627


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:14 | Loss:2.500197649002075 | Accuracy:0.3077966101694915


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:15 | Loss:2.504734992980957 | Accuracy:0.3233898305084746


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:16 | Loss:2.4901981353759766 | Accuracy:0.3396610169491525


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:17 | Loss:2.4564080238342285 | Accuracy:0.36542372881355933


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:18 | Loss:2.6735761165618896 | Accuracy:0.34983050847457625


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:19 | Loss:2.7455430030822754 | Accuracy:0.36


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:20 | Loss:3.1568517684936523 | Accuracy:0.3376271186440678


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:21 | Loss:3.229506254196167 | Accuracy:0.3464406779661017


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:22 | Loss:3.116837501525879 | Accuracy:0.3559322033898305


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:23 | Loss:3.3821210861206055 | Accuracy:0.34983050847457625


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:24 | Loss:3.667248249053955 | Accuracy:0.328135593220339



1

In [30]:
trainer = Trainer(gpus=gpus, max_epochs=25, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:0 | Loss:3.604952096939087 | Accuracy:0.031186440677966103


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:1 | Loss:3.51139760017395 | Accuracy:0.04745762711864407


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:2 | Loss:3.5002284049987793 | Accuracy:0.06305084745762712


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:3 | Loss:3.4050915241241455 | Accuracy:0.0711864406779661


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:4 | Loss:3.424589157104492 | Accuracy:0.07389830508474576


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:5 | Loss:3.395509958267212 | Accuracy:0.07254237288135593


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:6 | Loss:3.3599977493286133 | Accuracy:0.08542372881355932


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:7 | Loss:3.3448259830474854 | Accuracy:0.0759322033898305


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:8 | Loss:3.2987756729125977 | Accuracy:0.08542372881355932


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:9 | Loss:3.3236124515533447 | Accuracy:0.09423728813559322


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:10 | Loss:3.2912700176239014 | Accuracy:0.09830508474576272


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:11 | Loss:3.2386019229888916 | Accuracy:0.11322033898305085


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:12 | Loss:3.2127442359924316 | Accuracy:0.10847457627118644


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:13 | Loss:3.282404661178589 | Accuracy:0.10915254237288136


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:14 | Loss:3.2815961837768555 | Accuracy:0.12271186440677966


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:15 | Loss:3.3165860176086426 | Accuracy:0.1423728813559322


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:16 | Loss:3.4761595726013184 | Accuracy:0.13288135593220338


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:17 | Loss:3.692209482192993 | Accuracy:0.12406779661016949


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:18 | Loss:4.045893669128418 | Accuracy:0.13898305084745763


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:19 | Loss:4.537643909454346 | Accuracy:0.14508474576271185


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:20 | Loss:5.022345066070557 | Accuracy:0.13966101694915253


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:21 | Loss:5.505904674530029 | Accuracy:0.14779661016949153


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:22 | Loss:5.6631903648376465 | Accuracy:0.14915254237288136


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:23 | Loss:5.990276336669922 | Accuracy:0.12271186440677966


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:24 | Loss:6.015255928039551 | Accuracy:0.13559322033898305



1

In [31]:
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=64, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, num_workers=4, shuffle=False)

In [32]:
trainer = Trainer(gpus=gpus, max_epochs=25, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:0 | Loss:3.54060697555542 | Accuracy:0.3735593220338983


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:1 | Loss:3.6123523712158203 | Accuracy:0.3776271186440678


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:2 | Loss:3.6704928874969482 | Accuracy:0.3905084745762712


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:3 | Loss:3.7412772178649902 | Accuracy:0.3898305084745763


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:4 | Loss:3.6899402141571045 | Accuracy:0.3891525423728814


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:5 | Loss:3.7951247692108154 | Accuracy:0.3959322033898305


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:6 | Loss:3.8599658012390137 | Accuracy:0.3898305084745763


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:7 | Loss:3.8910741806030273 | Accuracy:0.39389830508474577


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:8 | Loss:4.046212196350098 | Accuracy:0.3871186440677966


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:9 | Loss:4.152932167053223 | Accuracy:0.37966101694915255


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:10 | Loss:4.072637557983398 | Accuracy:0.38305084745762713


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:11 | Loss:4.0407257080078125 | Accuracy:0.39322033898305087


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:12 | Loss:4.077968597412109 | Accuracy:0.38169491525423727


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:13 | Loss:4.097992897033691 | Accuracy:0.3911864406779661


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:14 | Loss:4.074225425720215 | Accuracy:0.3945762711864407


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:15 | Loss:4.119948863983154 | Accuracy:0.3905084745762712


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:16 | Loss:4.232583045959473 | Accuracy:0.38508474576271184


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:17 | Loss:4.135111331939697 | Accuracy:0.3905084745762712


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:18 | Loss:5.55678653717041 | Accuracy:0.288135593220339


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:19 | Loss:4.31779670715332 | Accuracy:0.3708474576271186


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:20 | Loss:3.84808349609375 | Accuracy:0.38033898305084746


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:21 | Loss:3.9929111003875732 | Accuracy:0.39796610169491525


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:22 | Loss:3.9850997924804688 | Accuracy:0.3891525423728814


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:23 | Loss:4.030226230621338 | Accuracy:0.3905084745762712


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:24 | Loss:4.062991142272949 | Accuracy:0.39864406779661016



1

In [33]:
trainer = Trainer(gpus=gpus, max_epochs=25, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:0 | Loss:6.449962615966797 | Accuracy:0.1464406779661017


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:1 | Loss:6.826666831970215 | Accuracy:0.14101694915254237


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:2 | Loss:6.971066474914551 | Accuracy:0.14983050847457627


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:3 | Loss:7.010488510131836 | Accuracy:0.15457627118644068


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:4 | Loss:7.043515205383301 | Accuracy:0.1464406779661017


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:5 | Loss:6.943863868713379 | Accuracy:0.14372881355932204


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:6 | Loss:7.146998405456543 | Accuracy:0.15389830508474575


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:7 | Loss:7.152970790863037 | Accuracy:0.14169491525423727


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:8 | Loss:7.2448859214782715 | Accuracy:0.14101694915254237


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:9 | Loss:7.3982672691345215 | Accuracy:0.1369491525423729


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:10 | Loss:7.3345160484313965 | Accuracy:0.14508474576271185


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:11 | Loss:7.635211944580078 | Accuracy:0.14576271186440679


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:12 | Loss:7.198373317718506 | Accuracy:0.1335593220338983


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:13 | Loss:7.670504570007324 | Accuracy:0.13627118644067795


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:14 | Loss:7.509610652923584 | Accuracy:0.14372881355932204


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:15 | Loss:7.9462714195251465 | Accuracy:0.14033898305084747


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:16 | Loss:7.655241966247559 | Accuracy:0.1369491525423729


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:17 | Loss:7.840564727783203 | Accuracy:0.1288135593220339


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:18 | Loss:7.267488479614258 | Accuracy:0.1430508474576271


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:19 | Loss:7.464041709899902 | Accuracy:0.14101694915254237


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:20 | Loss:7.323596000671387 | Accuracy:0.1423728813559322


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:21 | Loss:7.24222469329834 | Accuracy:0.14847457627118643


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:22 | Loss:7.451413631439209 | Accuracy:0.14508474576271185


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:23 | Loss:7.660017967224121 | Accuracy:0.14101694915254237


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Epoch:24 | Loss:7.5283966064453125 | Accuracy:0.14033898305084747



1