In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder

from tqdm import tqdm,trange
import copy
import torchvision.models as models

In [2]:

# Side length of the square network input.
IMG_SIZE = 224

# Training-time augmentation pipeline.
# Resize(IMG_SIZE) receives a single int, so it scales the *shorter* image side
# to IMG_SIZE and keeps the aspect ratio; the fixed IMG_SIZE x IMG_SIZE shape
# comes from the RandomResizedCrop that follows.
train_tfm = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    # Random crop covering 75-100% of the area with a mild aspect-ratio jitter,
    # resized to IMG_SIZE x IMG_SIZE.
    transforms.RandomResizedCrop((IMG_SIZE, IMG_SIZE), scale=(0.75, 1.0), ratio=(0.8, 1.25)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),

    transforms.ToTensor(),
    # RandomErasing works on tensors, so it has to come after ToTensor().
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=0, inplace=False),
])

# Deterministic evaluation pipeline: shorter side -> IMG_SIZE, then a centered
# IMG_SIZE x IMG_SIZE crop.
test_tfm = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])



In [3]:
batch_size = 15


def rgb_loader(path):
    """Open the image at `path` and return it converted to RGB.

    The file is opened in a `with` block so its handle is closed as soon as the
    pixel data has been converted; `convert` returns a new, fully loaded image
    that does not depend on the open file.
    """
    with Image.open(path) as img:
        return img.convert('RGB')


train_set = DatasetFolder("./hw5_data/train", loader=rgb_loader, extensions="jpg", transform=train_tfm)

test_set = DatasetFolder("./hw5_data/test", loader=rgb_loader, extensions="jpg", transform=test_tfm)

print(len(train_set))
print(len(test_set))

# Construct data loaders.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
# Keep shuffle=False: the submission cell numbers predictions by the order in
# which this loader yields samples.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

1500
150


In [6]:
# ---------- Training ----------
# Trains ResNeXt-50 (32x4d) from scratch on `train_loader`, evaluates on
# `test_loader` after every epoch, adapts the learning rate on validation
# accuracy, and checkpoints the best weights to "model.ckpt".
# NOTE(review): `test_loader` is used both for model selection here and for the
# final "test acc" in the evaluation cell, so that number is optimistically
# biased. Holding out part of ./hw5_data/train for validation would avoid this.

SEED = 0
# Makes weight init and shuffling repeatable (cuDNN kernels may still be
# non-deterministic).
torch.manual_seed(SEED)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = models.resnext50_32x4d(pretrained=False).to(device)
#model.load_state_dict(torch.load("model.ckpt", map_location=device))
model.device = device

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
# The scheduler is stepped with validation accuracy, hence mode='max'.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=6, cooldown=0)

# The number of training epochs.
n_epochs = 150
# Stop once this many consecutive epochs pass without a new best validation accuracy.
early_stop = 20
# Checkpoints are written only when the best validation accuracy exceeds this.
# If no epoch clears it, nothing is saved and the evaluation cell will load
# whatever "model.ckpt" already exists (or fail if there is none).
save_threshold = 0.7


def _run_epoch(net, loader, loss_fn, dev, opt=None):
    """Run one pass over `loader` and return (mean loss, accuracy) as floats.

    If `opt` is given the network is trained (train mode, backward pass,
    gradient clipping, optimizer step); otherwise it is only evaluated in eval
    mode with gradients disabled. Both metrics are averaged over samples, not
    over batches, so a short final batch is weighted correctly.
    """
    training = opt is not None
    net.train(training)
    total_loss, total_correct, total_seen = 0.0, 0, 0
    batches = tqdm(loader) if training else loader

    with torch.set_grad_enabled(training):
        for imgs, labels in batches:
            imgs, labels = imgs.to(dev), labels.to(dev)
            logits = net(imgs)
            loss = loss_fn(logits, labels)

            if training:
                opt.zero_grad()
                loss.backward()
                # Clip the gradient norms for stable training.
                nn.utils.clip_grad_norm_(net.parameters(), max_norm=10)
                opt.step()

            n = labels.size(0)
            total_loss += loss.item() * n
            total_correct += (logits.argmax(dim=-1) == labels).sum().item()
            total_seen += n

    if total_seen == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen


now_epoch = 0    # consecutive epochs without improvement
best_acc = 0.0

for epoch in range(n_epochs):
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)

    # ---------- Validation ----------
    valid_loss, valid_acc = _run_epoch(model, test_loader, criterion, device)

    # Print the information.
    print(f"[{epoch + 1:03d}/{n_epochs:03d}] Train | loss = {train_loss:.3f}, acc = {train_acc:.3f}, Valid | loss = {valid_loss:.3f}, acc = {valid_acc:.3f}, lr = {optimizer.param_groups[0]['lr']:.7f}")

    scheduler.step(valid_acc)

    if valid_acc > best_acc:
        best_acc = valid_acc
        now_epoch = 0
        if best_acc > save_threshold:
            torch.save(model.state_dict(), "model.ckpt")
            print(f"best_model saved, accuracy: {best_acc}")
        else:
            print(f"best_model, accuracy: {best_acc}")
    else:
        now_epoch += 1

    # `>=` so training stops after exactly `early_stop` stale epochs.
    if now_epoch >= early_stop:
        print("early stop")
        break
print(f"training end, valid acc:{best_acc}")


100%|██████████| 100/100 [00:40<00:00,  2.46it/s]
  0%|          | 0/100 [00:00<?, ?it/s][001/150] Train | loss = 3.056, acc = 0.125, Valid | loss = 5.320, acc = 0.087, lr = 0.0010000
best_model, accuracy: 0.08666666597127914
100%|██████████| 100/100 [00:40<00:00,  2.49it/s]
  0%|          | 0/100 [00:00<?, ?it/s][002/150] Train | loss = 2.286, acc = 0.249, Valid | loss = 2.404, acc = 0.220, lr = 0.0010000
best_model, accuracy: 0.2200000137090683
100%|██████████| 100/100 [00:40<00:00,  2.47it/s]
  0%|          | 0/100 [00:00<?, ?it/s][003/150] Train | loss = 2.072, acc = 0.309, Valid | loss = 2.393, acc = 0.220, lr = 0.0010000
100%|██████████| 100/100 [00:40<00:00,  2.48it/s]
  0%|          | 0/100 [00:00<?, ?it/s][004/150] Train | loss = 1.917, acc = 0.352, Valid | loss = 3.089, acc = 0.200, lr = 0.0010000
100%|██████████| 100/100 [00:40<00:00,  2.47it/s]
  0%|          | 0/100 [00:00<?, ?it/s][005/150] Train | loss = 1.784, acc = 0.381, Valid | loss = 1.890, acc = 0.400, lr = 0.00100

In [7]:
# Drop the training-time model; the evaluation cell below rebuilds the network
# and reloads the saved weights from "model.ckpt".
del model

In [9]:
# Eval mode
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = models.resnext50_32x4d(pretrained=False).to(device)
model.device = device
model.load_state_dict(torch.load("model.ckpt"))
model.eval()

predictions = []
test_accs=[]

# Iterate the testing set by batches.
for batch in tqdm(test_loader):
    
    imgs, labels = batch

    # Using torch.no_grad() accelerates the forward process.
    with torch.no_grad():
        logits = model(imgs.to(device))

    acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()
    test_accs.append(acc)

    # Take the class with greatest logit as prediction and record it.
    predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

test_acc = sum(test_accs) / len(test_accs)
print(f"test acc: {test_acc:.4f}")

# Save predictions into the file.
with open(f"prediction.csv", "w") as f:

    # The first row must be "Id, Category"
    f.write("Id,Category\n")

    # For the rest of the rows, each image id corresponds to a predicted class.
    for i, pred in  enumerate(predictions):
        f.write(f"{i},{pred}\n")
    print("prediction saved")

100%|██████████| 10/10 [00:01<00:00,  8.11it/s]test acc: 0.8533
prediction saved

