In [1]:
from __future__ import print_function 
from __future__ import division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import time
import os
import copy

from PIL import ImageFile

In [2]:
data_directory = '/red/ruogu.fang/share/emotion_adversarial_attack/data/processed/EmoSet-118K/'
model_name = 'WSCNet'
epochs = 100

In [3]:
TRAIN_DATA_PATH = os.path.join(data_directory, 'train')
VAL_DATA_PATH = os.path.join(data_directory, 'val')
TEST_DATA_PATH = os.path.join(data_directory, 'test')

BATCH_SIZE = 10
LEARNING_RATE = 0.003

TRANSFORM_IMG = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
    ])


train_data = torchvision.datasets.ImageFolder(root=TRAIN_DATA_PATH, transform=TRANSFORM_IMG)
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,  num_workers=4)

val_data = torchvision.datasets.ImageFolder(root=VAL_DATA_PATH, transform=TRANSFORM_IMG)
val_data_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=True,  num_workers=4)

test_data = torchvision.datasets.ImageFolder(root=TEST_DATA_PATH, transform=TRANSFORM_IMG)
test_data_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True,  num_workers=4)

In [4]:
from model import WSCNet, CAERSNet, PDANet, Stimuli_Aware_VEA

if model_name == 'WSCNet':
    model = WSCNet()
elif model_name == 'CAERSNet':
    model = CAERSNet()
elif model_name=='PDANet':
    model = PDANet()
elif model_name=='Stimuli_Aware_VEA':
    model = Stimuli_Aware_VEA()

Stimuli_Aware_VEA(
  (fcn): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2

In [5]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

In [6]:
best_loss = 0
best_model = None
best_metric_epoch = 0

for epoch in range(epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(train_data_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        
        
    # validation mode
    if (epoch + 1) % val_interval == 0:
        model.eval()
        running_loss_val = 0
        with torch.no_grad():
            for val_step, val_data in enumerate(val_data_loader):
                val_images, val_labels = (
                    val_data[0].to(device, non_blocking=True),
                    val_data[1].to(device, non_blocking=True),
                )
                
                outputs = model(val_images)
                val_loss = criterion(outputs, labels) 
                running_loss_val += val_loss
            

            if epoch == 0:
                best_loss = running_loss_val
                best_model = model
                best_metric_epoch = epoch

                    
                torch.save(best_model.state_dict(),
                                    os.path.join(args.logdir,
                                                model_name
                                                + str(best_metric_epoch + 1)
                                                + f"loss{epoch_loss:.2f}"
                                                + f"r2{best_metric:.2f}"
                                                + '.pth'))

            elif epoch != 0:
                if running_loss_val < best_loss:  # val_loss < best_loss: # YY I think this should use best_metric instead of loss to save the best model
                    best_loss = running_loss_val
                    best_model = model
                    best_metric_epoch = epoch
                    if dist.get_rank() == 0:
                        print(f"best_result={best_metric}, best model has been updated")

                        torch.save(best_model.module.state_dict(),
                                    os.path.join(args.logdir,
                                                    model_name
                                                + str(best_metric_epoch + 1)
                                                + f"loss{epoch_loss:.2f}"
                                                + f"r2{best_metric:.2f}"
                                                + '.pth'))

            print(
                f"current epoch: {epoch + 1}, current MSE: {epoch_val_loss}",
                f" best MSE: {best_loss}",
                f" at epoch: {best_metric_epoch + 1}"
            )

best_model_wts = copy.deepcopy(best_model.module.state_dict())
print('Finished Training')

UnidentifiedImageError: Caught UnidentifiedImageError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/apps/pytorch/1.8.1/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/apps/pytorch/1.8.1/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/apps/pytorch/1.8.1/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/apps/pytorch/1.8.1/lib/python3.9/site-packages/torchvision/datasets/folder.py", line 178, in __getitem__
    sample = self.loader(path)
  File "/apps/pytorch/1.8.1/lib/python3.9/site-packages/torchvision/datasets/folder.py", line 215, in default_loader
    return pil_loader(path)
  File "/apps/pytorch/1.8.1/lib/python3.9/site-packages/torchvision/datasets/folder.py", line 196, in pil_loader
    img = Image.open(f)
  File "/apps/pytorch/1.8.1/lib/python3.9/site-packages/PIL/Image.py", line 2967, in open
    raise UnidentifiedImageError(
PIL.UnidentifiedImageError: cannot identify image file <_io.BufferedReader name='/red/ruogu.fang/share/emotion_adversarial_attack/data/processed/EmoSet-118K/train/Sadness/sadness_06007.jpg'>
