In [1]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5115  100  5115    0     0    626      0  0:00:08  0:00:08 --:--:--  1185
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Found existing installation: torch 1.5.0
Uninstalling torch-1.5.0:
  Successfully uninstalled torch-1.5.0
Found existing installation: torchvision 0.6.0a0+35d732a
Uninstalling torchvision-0.6.0a0+35d732a:
Done updating TPU runtime
  Successfully uninstalled torchvision-0.6.0a0+35d732a
Copying gs://tpu-pytorch/wheels/torch-nightly-cp37-cp37m-linux_x86_64.whl...
\ [1 files][109.0 MiB/109.0 MiB]  108.8 MiB/s                                   
Operation completed over 1 objects/109.0 MiB.                                    
Copying gs://tpu-pytorch/wheels/torch_xla-nightly-cp37-cp37m-linux_x86_64.whl...
\ [1 files][124.6 MiB/124.6 MiB]  112.2 MiB/s                       

In [2]:
!export XLA_USE_BF16=1

In [3]:
!pip install efficientnet_pytorch

Collecting efficientnet_pytorch
  Downloading efficientnet_pytorch-0.6.3.tar.gz (16 kB)
Building wheels for collected packages: efficientnet-pytorch
  Building wheel for efficientnet-pytorch (setup.py) ... [?25ldone
[?25h  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.6.3-py3-none-any.whl size=12419 sha256=8ed31bd74ee8ae260075b3353c38ca5a356f93d91fc51a730fbb31b518baa96b
  Stored in directory: /root/.cache/pip/wheels/90/6b/0c/f0ad36d00310e65390b0d4c9218ae6250ac579c92540c9097a
Successfully built efficientnet-pytorch
Installing collected packages: efficientnet-pytorch
Successfully installed efficientnet-pytorch-0.6.3
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


## Import Packages

In [4]:
import gc
import os
import numpy as np
import pandas as pd
from PIL import Image, ImageFile

import torch
import torch.nn as nn
from torch.nn import functional as F

from sklearn import metrics
from sklearn import model_selection

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

from joblib import Parallel, delayed
import efficientnet_pytorch
import albumentations
from tqdm import tqdm

ImageFile.LOAD_TRUNCATED_IMAGES = True

## Create Kfolds

In [5]:
df = pd.read_csv("../input/jpeg-melanoma-256x256/train.csv")
df["kfold"] = -1
df = df.sample(frac=1).reset_index(drop=True)
y = df.target.values
kf = model_selection.StratifiedKFold(n_splits=8)
for fold_, (train_idx, test_idx) in enumerate(kf.split(X=df, y=y)):
    df.loc[test_idx, "kfold"] = fold_
df.to_csv("train_folds.csv", index=False)

## Dataloader

In [6]:
class ClassificationDataset:
    def __init__(self, image_paths, targets, resize, augmentations=None):
        self.image_paths = image_paths
        self.targets = targets
        self.resize = resize
        self.augmentations = augmentations

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        image = Image.open(self.image_paths[item])
        targets = self.targets[item]
        if self.resize is not None:
            image = image.resize(
                (self.resize[1], self.resize[0]), resample=Image.BILINEAR
            )
        image = np.array(image)
        targets = np.array(targets)
        if self.augmentations is not None:
            augmented = self.augmentations(image=image)
            image = augmented["image"]
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        return {
            "image": torch.tensor(image, dtype=torch.float),
            "targets": torch.tensor(targets, dtype=torch.long),
        }

## Model

In [7]:
class EfficientNet(nn.Module):
    def __init__(self):
        super(EfficientNet, self).__init__()
        self.base_model = efficientnet_pytorch.EfficientNet.from_pretrained(
            'efficientnet-b7'
        )
        self.base_model._fc = nn.Linear(
            in_features=2560,
            out_features=1,
            bias=True
        )

    def forward(self, image, targets):
        out = self.base_model(image)
        loss = nn.BCEWithLogitsLoss()(out, targets.view(-1, 1).type_as(out))
        return out, loss

## Train function

In [8]:
def train(data_loader=None, model=None, optimizer=None, scheduler=None, device=None):
    model.train()
    para_loader = pl.ParallelLoader(data_loader, [device])
    tk0 = para_loader.per_device_loader(device)
    for b_idx, data in enumerate(tk0):
        for key, value in data.items():
            data[key] = value.to(device)
        optimizer.zero_grad()
        _, loss = model(**data)
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        scheduler.step(loss)
        return loss.item()

## Eval function

In [9]:
def evaluate(
        data_loader,
        model,
        device,
):
    model.eval()
    with torch.no_grad():
        para_loader = pl.ParallelLoader(data_loader, [device])
        tk0 = para_loader.per_device_loader(device)
        for b_idx, data in enumerate(tk0):
            for key, value in data.items():
                data[key] = value.to(device)
            predictions, loss = model(**data)
            predictions = torch.sigmoid(predictions)
    return loss.item(), predictions


## Train loop

In [10]:
def run(fold):
    training_data_path = "../input/siic-isic-224x224-images/train/"
    df = pd.read_csv("./train_folds.csv")
    device = xm.xla_device()
    epochs = 5
    train_bs = 32
    valid_bs = 16
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)
    
    model = EfficientNet()
    model.to(device)
    
    train_aug = albumentations.Compose(
        [
            albumentations.Normalize(
                mean, 
                std, 
                max_pixel_value=255.0, 
                always_apply=True
            ),
            albumentations.ShiftScaleRotate(
                shift_limit=0.0625, 
                scale_limit=0.1, 
                rotate_limit=15
            ),
            albumentations.Flip(p=0.5)
        ]
    )

    valid_aug = albumentations.Compose(
        [
            albumentations.Normalize(
                mean, 
                std, 
                max_pixel_value=255.0,
                always_apply=True
            )
        ]
    )
    
    train_images = df_train.image_name.values.tolist()
    train_images = [
        os.path.join(training_data_path, i + ".png") for i in train_images
    ]
    train_targets = df_train.target.values

    valid_images = df_valid.image_name.values.tolist()
    valid_images = [
        os.path.join(training_data_path, i + ".png") for i in valid_images
    ]
    valid_targets = df_valid.target.values

    train_dataset = ClassificationDataset(
        image_paths=train_images,
        targets=train_targets,
        resize=None,
        augmentations=train_aug
    )
    
    train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True
    )
    
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_bs,
        sampler=train_sampler,
        drop_last=True,
        num_workers=0
    )
    
    valid_dataset = ClassificationDataset(
        image_paths=valid_images,
        targets=valid_targets,
        resize=None,
        augmentations=valid_aug
    )
    
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
      valid_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True
    )
    
    
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=valid_bs,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=0
    )
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        patience=3,
        threshold=0.001,
        mode="min"
    )
    
    for epoch in range(epochs):
        training_loss = train(
            data_loader=train_loader,
            model=model,
            optimizer=optimizer,
            device=device,
            scheduler=scheduler,
        )
        
        valid_loss, predictions = evaluate(
            valid_loader,
            model,
            device,
        
        )
        
        xm.master_print(f"Epoch = {epoch}, LOSS = {valid_loss}")
        gc.collect()
    

In [None]:
def _mp_fn(rank, flags):
    torch.set_default_tensor_type('torch.FloatTensor')
    a = run(0)
    

FLAGS={}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')


Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b7-dcc49843.pth
Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b7-dcc49843.pth
Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b7-dcc49843.pth
Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b7-dcc49843.pth
Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b7-dcc49843.pth
Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0

HBox(children=(FloatProgress(value=0.0, max=266860719.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=266860719.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=266860719.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=266860719.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=266860719.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=266860719.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=266860719.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=266860719.0), HTML(value='')))



Loaded pretrained weights for efficientnet-b7
Loaded pretrained weights for efficientnet-b7



Loaded pretrained weights for efficientnet-b7
Loaded pretrained weights for efficientnet-b7
Loaded pretrained weights for efficientnet-b7

Loaded pretrained weights for efficientnet-b7


Loaded pretrained weights for efficientnet-b7
Loaded pretrained weights for efficientnet-b7
Epoch = 0, LOSS = 0.7496073246002197
