In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    transforms.RandomHorizontalFlip(),
    # transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
    transforms.RandomAffine(
        degrees=0,
        translate=(0.05, 0.05),
        shear=5,
        scale=(0.8, 1.2)
    ),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
# RNN Model
class RecurrentBlock(nn.Module):
    def __init__(self):
        super(RecurrentBlock, self).__init__()
        self.rnn = nn.LSTM(96, 128, batch_first=True)
        self.fc = nn.Linear(128, 10)
    
    def forward(self, x):
        x = x.view(x.size(0), 32, 96)
        out, _ = self.rnn(x)
        out = self.fc(out[:, -1, :])
        return out

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

model = RecurrentBlock().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training
for epoch in range(5):
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')

# Evaluation
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total:.2f}%')

Using device: cuda
Epoch 1, Loss: 1.8802
Epoch 2, Loss: 1.7226
Epoch 3, Loss: 1.5773
Epoch 4, Loss: 1.5998
Epoch 5, Loss: 1.4970
Accuracy: 44.63%


Using device: cuda
Epoch 1, Loss: 1.8280
Epoch 2, Loss: 1.7222
Epoch 3, Loss: 1.6249
Epoch 4, Loss: 1.4987
Epoch 5, Loss: 1.3508
Accuracy: 46.58%