In [1]:
from __future__ import print_function 
from __future__ import division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import time
import os
import copy

from PIL import ImageFile

In [22]:
data_directory = '/red/ruogu.fang/share/emotion_adversarial_attack/data/processed/EmoSet-118K/'
model_name = 'CAERSNet'

In [23]:
from config import WSCNet_Config, CAERSNet_Config, PDANet_Config, Stimuli_Aware_VEA_Config

if model_name == 'WSCNet':
    config = WSCNet_Config()
elif model_name == 'CAERSNet':
    config = CAERSNet_Config()
elif model_name=='PDANet':
    config = PDANet_Config()
elif model_name=='Stimuli_Aware_VEA':
    config = Stimuli_Aware_VEA_Config()
    
model = config.model

In [24]:
config.epoch

100

In [25]:
TRAIN_DATA_PATH = os.path.join(data_directory, 'train')
VAL_DATA_PATH = os.path.join(data_directory, 'val')
TEST_DATA_PATH = os.path.join(data_directory, 'test')

epochs = config.epoch
BATCH_SIZE = config.batch_size
LEARNING_RATE = config.learning_rate

TRANSFORM_IMG = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
    ])


train_data = torchvision.datasets.ImageFolder(root=TRAIN_DATA_PATH, transform=TRANSFORM_IMG)
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,  num_workers=1)

val_data = torchvision.datasets.ImageFolder(root=VAL_DATA_PATH, transform=TRANSFORM_IMG)
val_data_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=True,  num_workers=1)

test_data = torchvision.datasets.ImageFolder(root=TEST_DATA_PATH, transform=TRANSFORM_IMG)
test_data_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True,  num_workers=1)

In [26]:
print(f"Number of training samples {len(train_data)}")
print(f"Number of validation samples {len(val_data)}")
print(f"Number of testing samples {len(test_data)}")

Number of training samples 94481
Number of validation samples 5905
Number of testing samples 17716


In [27]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
weight = torch.Tensor([1 - 15681/94481, 1 - 8434/94481, 1 - 12022/94481, 1 - 13036/94481, 1 - 8467/94481, 1 - 15837/94481, 1 - 10786/94481, 1- 10218/94481]).to(device)
criterion = nn.CrossEntropyLoss(weight)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
model.to(device)

CAERSNet(
  (two_stream_net): TwoStreamNetwork(
    (face_encoding_module): Encoder(
      (convs): ModuleList(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (bn): ModuleList(
        (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3-4): 2 x BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilat

In [28]:
best_loss = 0
best_model = None
best_metric_epoch = 0
val_interval = 1

for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(val_data_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        
    print('epoch:', str(epoch), ' /loss:', str(running_loss))
          
    # validation mode
    if (epoch + 1) % val_interval == 0:
        model.eval()
        running_loss_val = 0
        with torch.no_grad():
            for val_step, val_data in enumerate(val_data_loader):
                val_images, val_labels = (
                    val_data[0].to(device, non_blocking=True),
                    val_data[1].to(device, non_blocking=True),
                )
                
                outputs = model(val_images)
                val_loss = criterion(outputs, val_labels) 
                running_loss_val += val_loss
            

            if epoch == 0:
                best_loss = running_loss_val
                best_model = model
                best_metric_epoch = epoch
 
                #torch.save(best_model.state_dict(),
                #                    os.path.join(args.logdir,
                #                                model_name
                #                                + str(best_metric_epoch + 1)
                #                                + f"loss{epoch_loss:.2f}"
                #                                + f"r2{best_metric:.2f}"
                #                                + '.pth'))

            elif epoch != 0:
                if running_loss_val < best_loss:  # val_loss < best_loss: # YY I think this should use best_metric instead of loss to save the best model
                    best_loss = running_loss_val
                    best_model = model
                    best_metric_epoch = epoch
                    if dist.get_rank() == 0:
                        print(f"best_result={best_metric}, best model has been updated")

                        torch.save(best_model.state_dict(),
                                    os.path.join(args.logdir,
                                                    model_name
                                                + str(best_metric_epoch + 1)
                                                + f"loss{epoch_loss:.2f}"
                                                + f"r2{best_metric:.2f}"
                                                + '.pth'))

            print(
                f"current epoch: {epoch + 1}, current MSE: {running_loss_val}",
                f" best MSE: {best_loss}",
                f" at epoch: {best_metric_epoch + 1}"
            )

best_model_wts = copy.deepcopy(best_model.state_dict())
print('Finished Training')

TypeError: conv2d() received an invalid combination of arguments - got (bool, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!bool!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!bool!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)


In [8]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
#model = SwinforRegression(model_name_or_path='microsoft/swin-large-patch4-window12-384-in22k', n_label=7)
#model.to(device)

#checkpoint = torch.load('/red/ruogu.fang/yyang/ADRiskFactorPrediction/swin/classification/savedmodel/Swin_classification/Swin_classification6loss5.16r20.66.pth')
checkpoint = torch.load('/red/ruogu.fang/leem.s/Emotion-Adversarial-Attack/hpg-code/slurm/logs/WSCNet/WSCNet6loss0.89.pth')
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [17]:
train_data.class_to_idx

{'Amusement': 0,
 'Anger': 1,
 'Awe': 2,
 'Contentment': 3,
 'Disgust': 4,
 'Excitement': 5,
 'Fear': 6,
 'Sadness': 7}

In [21]:
from torchmetrics.classification import MulticlassAccuracy
model.eval()

y_true = torch.Tensor().to(device)
y_pred = torch.Tensor().to(device)
valid_accuracy = MulticlassAccuracy(num_classes=8, average=None).to(device)

with torch.no_grad():
    for i, test_data in enumerate(test_data_loader):
        test_images, test_labels = (
            test_data[0].to(device, non_blocking=True),
            test_data[1].to(device, non_blocking=True))
        
        outputs = model(test_images)
        metric = valid_accuracy(torch.tensor(outputs), test_labels.long())
        
        y_true = torch.cat((test_labels.long(), y_true), dim=-1)
        y_pred = torch.cat((torch.tensor(outputs),y_pred), dim=0)

    total_valid_accuracy = valid_accuracy.compute()

print('Overall accuracy is {}'.format(total_valid_accuracy))

  metric = valid_accuracy(torch.tensor(outputs), test_labels.long())
  y_pred = torch.cat((torch.tensor(outputs),y_pred), dim=0)


Overall accuracy is tensor([0.5668, 0.6616, 0.6765, 0.5781, 0.7682, 0.8245, 0.7077, 0.6184],
       device='cuda:0')
