In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from kymatio.torch import Scattering2D
from tqdm import tqdm

In [2]:
# Run configuration shared by every cell below; tune values here only.
args = dict(
    learning_rate=1e-3,      # optimizer step size
    batch_size=256,          # samples per DataLoader batch
    num_worker=32,           # DataLoader worker processes
    random_seed=8771795,     # seed for torch's global RNG
    augmentation=False,      # toggle the random train/valid augmentations
    num_epoch=10,            # number of training epochs
    device='cuda',           # device the scattering transform runs on
)

In [3]:
# Seed torch's global RNG so the split and the shuffling below are repeatable.
torch.random.manual_seed(args['random_seed'])

# Deterministic preprocessing: convert to tensor, then normalize the single
# channel with mean 0.5 / std 0.5.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Without augmentation the train/valid pipeline is the same as the test one.
train_valid_transform = test_transform
if args['augmentation']:
    # NOTE(review): train and validation are both carved out of the same
    # dataset object below, so these random augmentations are also applied to
    # the validation samples. Confirm that this is intended.
    train_valid_transform = transforms.Compose([
        transforms.RandomResizedCrop((28,28)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.RandomErasing(),  # operates on tensors, hence placed after ToTensor
        transforms.Normalize((0.5,), (0.5,))
    ])

# Load dataset
# A download is required only when the dataset directory is missing (the
# original expression was inverted). The flag is informational: download=True
# is kept because torchvision does not re-download files that are already
# present, and it still recovers from an existing-but-empty directory.
require_download = not os.path.exists('./dataset')
train_valid_dataset = torchvision.datasets.FashionMNIST('./dataset', train=True, transform=train_valid_transform, download=True)
test_dataset = torchvision.datasets.FashionMNIST('./dataset', train=False, transform=test_transform, download=True)

# Split train and validation
# Re-seed immediately before the split so that the 54000/6000 partition does
# not depend on how much randomness was consumed above.
torch.random.manual_seed(args['random_seed'])
train_dataset, valid_dataset = torch.utils.data.random_split(train_valid_dataset, [54000, 6000])

# Generate dataloader
# Only the training loader shuffles; validation and test keep dataset order.
loader_kwargs = {'batch_size': args['batch_size'], 'num_workers': args['num_worker']}
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = torch.utils.data.DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [4]:
def _extract_features(scattering, loader, device):
    """Run ``scattering`` over every batch of ``loader`` and stack the outputs.

    Args:
        scattering: callable applied to each input batch (here a Scattering2D
            instance already moved to ``device``).
        loader: iterable yielding ``(inputs, labels)`` batches; the labels are
            ignored.
        device: device the inputs are moved to before the transform.

    Returns:
        A single tensor with the per-batch outputs concatenated along dim 0,
        in the order the loader produced them.
    """
    outputs = []
    # Pure feature extraction: no gradients are needed.
    with torch.no_grad():
        for x, _ in tqdm(loader):
            outputs.append(scattering(x.to(device)))
    return torch.cat(outputs, dim=0)


# Create the output directory if needed (no error if it already exists).
os.makedirs('./features', exist_ok=True)

# Compute and store scattering features for scales J = 1..4.
for j in range(1, 5):
    model = Scattering2D(J=j, shape=(28, 28)).to(args['device'])

    # NOTE(review): train_loader shuffles, so train_feats is stored in the
    # shuffled order of this pass and no labels are saved alongside it.
    # Confirm that downstream code can recover the sample/label alignment.
    train_feats = _extract_features(model, train_loader, args['device'])
    valid_feats = _extract_features(model, valid_loader, args['device'])
    test_feats = _extract_features(model, test_loader, args['device'])

    # Pass the path so torch.save opens and closes the file itself; the
    # original handed it an open() handle that was never closed.
    # The filename spelling is kept as-is because other code may load it by name.
    torch.save((train_feats, valid_feats, test_feats), f'./features/scaterring_J{j}.pt')

100%|██████████| 211/211 [00:01<00:00, 135.97it/s]
100%|██████████| 24/24 [00:00<00:00, 33.30it/s]
100%|██████████| 40/40 [00:00<00:00, 48.68it/s]
100%|██████████| 211/211 [00:05<00:00, 36.72it/s]
100%|██████████| 24/24 [00:01<00:00, 19.89it/s]
100%|██████████| 40/40 [00:01<00:00, 25.39it/s]
100%|██████████| 211/211 [00:13<00:00, 15.75it/s]
100%|██████████| 24/24 [00:02<00:00, 11.65it/s]
100%|██████████| 40/40 [00:03<00:00, 13.24it/s]
100%|██████████| 211/211 [00:24<00:00,  8.45it/s]
100%|██████████| 24/24 [00:03<00:00,  7.08it/s]
100%|██████████| 40/40 [00:05<00:00,  7.63it/s]
