In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from tqdm import tqdm

In [None]:
# Hyperparameters and runtime settings for the FashionMNIST fine-tuning run below.
args = {
    'learning_rate': 1e-3,
    'batch_size': 8,
    'num_worker': 8,
    'random_seed': 8771795,
    'augmentation': True,
    'num_epoch': 10,
    # Fall back to CPU so the notebook still runs (slowly) on machines without a GPU;
    # every `.to(args['device'])` below would otherwise fail.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

In [None]:
# Set random seed
torch.random.manual_seed(args['random_seed'])

# Define transformation
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_valid_transform = test_transform
if args['augmentation']:
    train_valid_transform = transforms.Compose([
        transforms.RandomResizedCrop((28,28)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        transforms.Normalize((0.5,), (0.5,))
    ])

# Load dataset
require_download = os.path.exists('./dataset')
train_valid_dataset = torchvision.datasets.FashionMNIST('./dataset', train=True, transform=train_valid_transform, download=True)
test_dataset = torchvision.datasets.FashionMNIST('./dataset', train=False, transform=test_transform, download=True)

# Split train and validation
torch.random.manual_seed(args['random_seed'])
train_dataset, valid_dataset = torch.utils.data.random_split(train_valid_dataset, [54000, 6000])

# Generate dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args['batch_size'], shuffle=True, num_workers=args['num_worker'])
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args['batch_size'], shuffle=False, num_workers=args['num_worker'])
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args['batch_size'], shuffle=False, num_workers=args['num_worker'])

In [None]:
model = torchvision.models.resnext50_32x4d(pretrained=True).to(args['device'])
criterion = nn.BCELoss().to(args['device'])
optimizer = optim.Adam(model.parameters(), lr=args['learning_rate'])

best_metric = 0.0
best_state_dict = model.state_dict()
for i in range(args['num_epoch']):
    # Train
    model.train()
    torch.set_grad_enabled(False)
    for x, y in tqdm(train_loader):
        x, y = x.to(args['device']), y.to(args['device'])
        ... # TODO        
    
    # Valid
    model.eval()
    torch.set_grad_enabled(False)
    for x, y in tqdm(valid_loader):
        x, y = x.to(args['device']), y.to(args['device'])
        ... # TODO        
        valid_metric = 1.0 # TODO
    
    # Choose best Validation Metric
    if best_metric <= valid_metric:
        best_metric = valid_metric
        best_state_dict = model.state_dict()        
    
# Test
model.load_state_dict(best_state_dict)
for x, y in tqdm(test_loader):
    x, y = x.to(args['device']), y.to(args['device'])
    ... # TODO        
    test_metric = 1.0 # TODO