In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm

from torch.utils.data import DataLoader
from torch.optim import Optimizer

def train_one_epoch(model: nn.Module, optimizer: Optimizer, data_loader: DataLoader, device, epoch,
                    max_grad_norm=None):
    """Train ``model`` for one pass over ``data_loader`` and return the mean loss.

    Every batch yielded by ``data_loader`` is indexed with the keys ``'image'``
    and ``'label'``; both entries are moved to ``device`` before the forward
    pass. The loss is cross-entropy between the model output and the labels.

    Args:
        model: Network to optimise; it is switched to training mode.
        optimizer: Optimizer that is stepped once per batch.
        data_loader: Iterable of batches (wrapped in a tqdm progress bar).
        device: Target device for the batch tensors.
        epoch: Zero-based epoch index; shown as ``epoch + 1`` in the progress bar.
        max_grad_norm: If not ``None``, gradients are clipped to this total
            norm after ``backward()`` and before ``optimizer.step()``.

    Returns:
        The sample-weighted mean loss over the epoch as a float, or NaN when
        the loader yielded no samples.
    """
    model.train()
    model.zero_grad()

    # Build the criterion once; it is stateless, so re-creating it per batch is wasted work.
    criterion = nn.CrossEntropyLoss()

    progress = tqdm(data_loader)
    seen_samples = 0
    loss_sum = 0.0
    for batch in progress:
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        batch_size = len(images)

        logits = model(images)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        # Clipping has to happen between backward() and step(): before backward()
        # there are no fresh gradients to clip.
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        # Read the scalar once and reuse it for both the running mean and the display.
        batch_loss = loss.item()
        seen_samples += batch_size
        # The criterion averages over the batch, so weight by batch size to get a per-sample mean.
        loss_sum += batch_loss * batch_size

        progress.set_description(
            f"Epoch {epoch + 1}, lr is {optimizer.param_groups[0]['lr']:.6f} loss {batch_loss:.3f}")

    if seen_samples == 0:
        return float('nan')
    return loss_sum / seen_samples

In [2]:
from PIL import Image
from torchvision import transforms

from torch.utils.data import Dataset

class TestDataset(Dataset):
    """Dataset that yields images only (no labels), one per stored path."""

    def __init__(self, img_paths, transform):
        """Store the indexable collection of image paths and the optional transform."""
        self.img_paths = img_paths
        self.transform = transform

    def __len__(self):
        """Return the number of stored image paths."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Open the image at ``index`` and return it, transformed when a transform is set."""
        path = self.img_paths[index]
        picture = Image.open(path)
        # A falsy transform (e.g. None) means the opened image is returned untouched.
        return self.transform(picture) if self.transform else picture


class TrainDataset_01(Dataset):
    """Dataset pairing each image path with its label.

    Items are returned as ``{'image': ..., 'label': ...}`` dictionaries.
    """

    def __init__(self, img_paths, labels, transform = None):
        """Store paths and labels; build the default transform when none is supplied."""
        self.img_paths = img_paths
        self.labels = labels
        # Default pipeline: resize to (512, 384), convert to a tensor, then normalise.
        # NOTE(review): Normalize is given 3-element mean/std, so the default
        # pipeline presumably expects 3-channel images - confirm inputs are RGB.
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.Resize((512, 384), Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
        ])

    def __len__(self):
        """Return the number of stored image paths."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Open the image at ``index``, transform it if a transform is set, and pair it with its label."""
        picture = Image.open(self.img_paths[index])
        target = self.labels[index]

        if self.transform:
            picture = self.transform(picture)
        return dict(image=picture, label=target)

In [4]:
import os
import pandas as pd

In [5]:
# Training hyper-parameters.
EPOCH = 10
lr = 0.001
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first .to(device) call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory holding the training data and its metadata CSV.
TRAIN_DIR = '/opt/ml/input/data/train'
train_info = pd.read_csv(os.path.join(TRAIN_DIR, 'train_info_01.csv'))

# Per-sample image paths and labels, taken from the 'path' and 'label' columns.
image_paths = train_info['path'].values
labels = train_info['label'].values

In [6]:
# Memory cleanup cell - presumably run to free memory before training; confirm intent.
import gc
# Run a full Python garbage-collection pass.
gc.collect()
# Ask PyTorch's CUDA caching allocator to release its unused cached memory.
torch.cuda.empty_cache()
