Source repository (GitHub): https://github.com/smartdanny/imagenette_starter

In [None]:
import matplotlib.pyplot as plt

In [None]:
import torch
import torchvision
import os
import tarfile
import hashlib

# Imagenette dataset: https://github.com/fastai/imagenette

# Choose the image-size variant to download.
datasets = {
    'full_sz': 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz',  # 1.5GB
    '320px': 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz',  # 326mb
    '160px': 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz',  # 94mb
}
# Expected MD5 per variant. Only the full-size archive's checksum is recorded;
# variants without an entry are downloaded when missing but not verified
# (add their checksums here to enable verification).
dataset_md5s = {
    'full_sz': 'fe2fc210e6bb7c5664d602c3cd71e612',
}

dataset_size = 'full_sz'
dataset_url = datasets[dataset_size]
expected_md5 = dataset_md5s.get(dataset_size)

dataset_filename = dataset_url.split('/')[-1]
dataset_foldername = dataset_filename.split('.')[0]
data_path = '../imagenette_data'
dataset_filepath = os.path.join(data_path, dataset_filename)
dataset_folderpath = os.path.join(data_path, dataset_foldername)

os.makedirs(data_path, exist_ok=True)


def _file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at *path*.

    The file is streamed in ``chunk_size``-byte pieces so the (up to 1.5GB)
    archive is never held in memory all at once; the handle is always closed.
    """
    md5_hash = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()


# (Re-)download when the archive is missing, or when a known checksum does
# not match (e.g. a previous download was interrupted).
download = not os.path.exists(dataset_filepath) or (
    expected_md5 is not None and _file_md5(dataset_filepath) != expected_md5
)
if download:
    # download_url verifies the md5 (when given) after downloading.
    torchvision.datasets.utils.download_url(
        dataset_url, data_path, filename=dataset_filename, md5=expected_md5
    )

# Extract only after a fresh download or when the extracted folder is absent.
# If an earlier extraction was interrupted, delete dataset_folderpath to force
# a clean re-extract.
if download or not os.path.isdir(dataset_folderpath):
    with tarfile.open(dataset_filepath, 'r:gz') as tar:
        tar.extractall(path=data_path)
    

In [None]:
# Imagenette ships 'train' and 'val' splits; the 'val' split is used as the test set here.
TRAIN_DIR, TEST_DIR = (os.path.join(dataset_folderpath, split) for split in ('train', 'val'))

In [None]:
from src.data_stuff import dataset_tools

In [None]:
# Per-channel mean/std for the (currently disabled) Normalize step.
# NOTE(review): these are the commonly used CIFAR-10 statistics. Imagenette is an
# ImageNet subset, so ImageNet stats (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225)
# are probably the better choice if normalization is re-enabled — confirm against
# what MySwinTransformer expects.
rgb_mean = (0.4914, 0.4822, 0.4465)
rgb_std = (0.2023, 0.1994, 0.2010)

# Both pipelines resize to a fixed 224x224 and convert to tensors. The
# augmentation and normalization steps are kept as opt-in toggles.
train_tfms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    # torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    # torchvision.transforms.Normalize(rgb_mean, rgb_std),
])
test_tfms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    # torchvision.transforms.Normalize(rgb_mean, rgb_std),
])

# ImageFolderWithPaths is used in place of torchvision.datasets.ImageFolder.
dataset_train = dataset_tools.ImageFolderWithPaths(TRAIN_DIR, train_tfms)
dataset_valid = dataset_tools.ImageFolderWithPaths(TEST_DIR, test_tfms)

batch_size = 32
num_workers = 6
train_dataloader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=True,  # every training batch has exactly batch_size samples
    shuffle=True,
)
# Validation must see every image exactly once, in a stable order: no
# shuffling, and no dropping of the final partial batch (which would silently
# skip up to batch_size - 1 images and bias the reported metrics).
val_dataloader = torch.utils.data.DataLoader(
    dataset_valid,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
    shuffle=False,
)

In [None]:
# Preview one training batch as an image grid (channels moved last for matplotlib).
# NOTE(review): assumes element [1] of each batch is the image tensor; this depends
# on the tuple order returned by ImageFolderWithPaths — confirm.
images = next(iter(train_dataloader))[1]
grid = torchvision.utils.make_grid(images, padding=20)
plt.imshow(grid.permute(1, 2, 0))

# Model

In [None]:
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import torchvision.models as models

# MY local imports
from src.callback_stuff import LogConfusionMatrix
from src.model_stuff import MySwinTransformer

In [None]:
# TensorBoard logs are written under lightning_logs/<EXP_NAME>/.
EXP_NAME = "imagenette_SwinT"
logger = TensorBoardLogger(save_dir="lightning_logs", name=EXP_NAME)

In [None]:
# One output per class folder found in the training split (10 for Imagenette).
model = MySwinTransformer.MySwinTransformer(num_classes=len(dataset_train.class_to_idx))
trainer = Trainer(
    # NOTE(review): `gpus` is deprecated in newer Lightning releases in favour of
    # accelerator="gpu", devices=1; kept for the Lightning version this notebook targets.
    gpus=1,
    max_epochs=15,
    # Without this, Lightning falls back to its default logger and EXP_NAME is ignored.
    logger=logger,
    callbacks=[
        LogConfusionMatrix.LogConfusionMatrix(class_to_idx=dataset_train.class_to_idx),
    ],
)

In [None]:
# Loaders are passed positionally: the `train_dataloader=` keyword was renamed to
# `train_dataloaders=` in PyTorch Lightning and the old name was later removed,
# whereas the positional order (model, train loaders, val loaders) is stable.
trainer.fit(model, train_dataloader, val_dataloader)