In [1]:
!pip install scattering-transform
!pip install transformers



In [2]:
import os
import glob
import time
import torch
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

import torchvision.models as models
from torchvision import transforms, utils
from torchvision.models.feature_extraction import create_feature_extractor

from PIL import Image
from skimage.transform import resize

import warnings
warnings.filterwarnings("ignore")

from scattering_transform import SCL, SCLTrainingWrapper
from transformers import ViTForImageClassification, ViTFeatureExtractor

### Load Dataset

In [3]:
class dataset(Dataset):
    def __init__(self, root_dir, dataset_type, img_size, transform=None, shuffle=False):
        self.root_dir = root_dir
        self.transform = transform
        self.file_names = [f for f in glob.glob(os.path.join(root_dir, "*", "*.npz")) \
                            if dataset_type in f]
        self.img_size = img_size
        self.shuffle = shuffle

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        data_path = self.file_names[idx]
        data = np.load(data_path)
        image = data["image"].reshape(16, 160, 160)
        target = data["target"]
        structure = data["structure"]
        meta_target = data["meta_target"]
        meta_structure = data["meta_structure"]

        if self.shuffle:
            context = image[:8, :, :]
            choices = image[8:, :, :]
            indices = list(range(8))
            np.random.shuffle(indices)
            new_target = indices.index(target)
            new_choices = choices[indices, :, :]
            image = np.concatenate((context, new_choices))
            target = new_target
        
        resize_image = []
        for idx in range(0, 16):
            resize_image.append(resize(image[idx,:,:], (self.img_size, self.img_size)))
        resize_image = np.stack(resize_image)

        embedding = torch.zeros((6, 300), dtype=torch.float)
        indicator = torch.zeros(1, dtype=torch.float)
        element_idx = 0
    
        del data
        if self.transform:
            resize_image = self.transform(resize_image)
            target = torch.tensor(target, dtype=torch.long)
            meta_target = self.transform(meta_target)
            meta_structure = self.transform(meta_structure)
            meta_target = torch.tensor(meta_target, dtype=torch.long)
        return resize_image, target, meta_target, meta_structure, embedding, indicator

#### Experiment Hyperparameters

In [4]:
class Args:
    
    def __init__(self,):
        self.model = 'ViT_SCL'
        self.epochs = 100
        self.batch_size = 32
        self.seed = 12345
        self.device = 1
        self.load_workers = 16
        self.resume = False
        self.path = '/common/home/ab2253/Desktop/data_new'
        self.save = './ckpt_res/'
        self.img_size = 80
        self.lr = 1e-4
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.epsilon = 1e-8
        self.meta_alpha = 0.
        self.meta_beta = 0.
        self.perc_train = 100
        self.verbose = False

        # ViT parameters
        self.vit_requires_grad = False
        self.vec2image_input_dim = 768

        # SCL Hyperparameters
        self.scl_image_size = 224
        self.scl_set_size = 9
        self.scl_conv_channels = [1, 16, 16, 32, 32, 32]
        self.scl_conv_output_dim = 80
        self.scl_attr_heads = 10
        self.scl_attr_net_hidden_dims = [128]
        self.scl_rel_heads = 80
        self.scl_rel_net_hidden_dims = [64, 23, 5]
        
args = Args()

args.cuda = torch.cuda.is_available()
torch.cuda.set_device(args.device)
torch.cuda.manual_seed(args.seed)

if not os.path.exists(args.save):
    os.makedirs(args.save)

#### Load Data

In [5]:
class ToTensor(object):
    def __call__(self, sample):
        return torch.tensor(sample, dtype=torch.float32)

train = dataset(args.path, "train", args.img_size, transform=transforms.Compose([ToTensor()]),shuffle=True)
valid = dataset(args.path, "val", args.img_size, transform=transforms.Compose([ToTensor()]))
test = dataset(args.path, "test", args.img_size, transform=transforms.Compose([ToTensor()]))

subset_indices = np.random.choice(len(train), len(train)*args.perc_train // 100, replace=False)
train_subset = Subset(train, subset_indices)

print("Number of samples in original train set =", len(train))
print("Number of samples in train subset =", len(train_subset))
print("All samples are unique =", len(subset_indices) == len(set(subset_indices)))

trainloader = DataLoader(train_subset, batch_size=args.batch_size, shuffle=True, num_workers=16)
validloader = DataLoader(valid, batch_size=args.batch_size, shuffle=False, num_workers=16)
testloader = DataLoader(test, batch_size=args.batch_size, shuffle=False, num_workers=16)

Number of samples in original train set = 24000
Number of samples in train subset = 24000
All samples are unique = True


### Model

In [6]:
class BasicModel(nn.Module):
    def __init__(self, args):
        super(BasicModel, self).__init__()
        self.name = args.model
    
    def load_model(self, path, epoch):
        state_dict = torch.load(path+'{}_epoch_{}.pth'.format(self.name, epoch))['state_dict']
        self.load_state_dict(state_dict)

    def save_model(self, path, epoch, acc, loss):
        torch.save({'state_dict': self.state_dict(), 'acc': acc, 'loss': loss}, path+'{}_epoch_{}.pth'.format(self.name, epoch))

    def compute_loss(self, output, target, meta_target, meta_structure):
        pass

    def train_(self, image, target, meta_target, meta_structure, embedding, indicator):
        self.optimizer.zero_grad()
        output = self(image, embedding, indicator)
        loss = self.compute_loss(output, target, meta_target, meta_structure)
        loss.backward()
        self.optimizer.step()
        pred = output[0].data.max(1)[1]
        correct = pred.eq(target.data).cpu().sum().numpy()
        accuracy = correct * 100.0 / target.size()[0]
        return loss.item(), accuracy

    def validate_(self, image, target, meta_target, meta_structure, embedding, indicator):
        with torch.no_grad():
            output = self(image, embedding, indicator)
        loss = self.compute_loss(output, target, meta_target, meta_structure)
        pred = output[0].data.max(1)[1]
        correct = pred.eq(target.data).cpu().sum().numpy()
        accuracy = correct * 100.0 / target.size()[0]
        return loss.item(), accuracy

    def test_(self, image, target, meta_target, meta_structure, embedding, indicator):
        with torch.no_grad():
            output = self(image, embedding, indicator)
        pred = output[0].data.max(1)[1]
        correct = pred.eq(target.data).cpu().sum().numpy()
        accuracy = correct * 100.0 / target.size()[0]
        return accuracy

In [7]:
def zeros(shape):
    return nn.init.zeros_(torch.empty(shape))

def glorot(shape):
    return nn.init.xavier_uniform_(torch.empty(shape), gain=1.)

class Vec2Image(nn.Module):

    def __init__(self, input_dim, output_dim, bias=True, act=F.relu):
        super(Vec2Image, self).__init__()

        if len(output_dim) != 3:
            raise ValueError("output_dim must be 3d.")

        
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.act = act

        self.weight = nn.Parameter(glorot((input_dim, output_dim[1]*output_dim[2])))

        if bias:
            self.bias = nn.Parameter(zeros((output_dim[1]*output_dim[2])))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        # Expected shape of x: (b, n)

        x = self.act(torch.matmul(x, self.weight) + self.bias)
        x = x.view((x.shape[0] // 8, 8, 1, self.output_dim[1], self.output_dim[2]))

        return x


#### Vision Transformer 

In [8]:
TO_IMG = transforms.ToPILImage()

def to_image(b):
    b = b.reshape((b.shape[0]*b.shape[1], 1, b.shape[2], b.shape[3]))
    trans_b = [TO_IMG(x) for x in b]
    return trans_b

class ViTSCL(BasicModel):

    def __init__(self, args):
        super(ViTSCL, self).__init__(args)

        self.id2label = {'opt' + str(k): k for k in range(8)}
        self.label2id = {k: 'opt' + str(k) for k in range(8)}

        self.encoder = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                                num_labels=8,
                                                                id2label=self.id2label,
                                                                label2id=self.label2id,
                                                                output_hidden_states=True) 

        for name, param in self.encoder.named_parameters():
            if param.requires_grad:
                param.requires_grad = args.vit_requires_grad

        self.vec2image = Vec2Image(input_dim=args.vec2image_input_dim, output_dim=(1, args.scl_image_size, args.scl_image_size))

        self.scl = SCL(
            image_size = args.scl_image_size,                           # size of image
            set_size = args.scl_set_size,                               # number of questions + 1 answer
            conv_channels = args.scl_conv_channels,                     # convolutional channel progression, 1 for greyscale, 3 for rgb
            conv_output_dim = args.scl_conv_output_dim,                 # model dimension, the output dimension of the vision net
            attr_heads = args.scl_attr_heads,                           # number of attribute heads
            attr_net_hidden_dims = args.scl_attr_net_hidden_dims,       # attribute scatter transform MLP hidden dimension(s)
            rel_heads = args.scl_rel_heads,                             # number of relationship heads
            rel_net_hidden_dims = args.scl_rel_net_hidden_dims          # MLP for relationship net
        )

        self.decoder = SCLTrainingWrapper(self.scl)

        self.verbose = args.verbose

        self.optimizer = optim.Adam(self.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), eps=args.epsilon)
        
        #self.Linear1 = torch.nn.Linear(768,768)
        #self.Linear2 = torch.nn.Linear(768,768)
        #self.Linear3 = torch.nn.Linear(768,768)
        #self.Linear4 = torch.nn.Linear(768,768)
        
        # Transforms
        self.grayscale_to_rgb = transforms.Lambda(lambda x: x.reshape((1, x.shape[0], x.shape[1])).repeat(3, 1, 1))
        self.reshape_input_batch = transforms.Lambda(lambda x: x.reshape((x.shape[0]*x.shape[1], x.shape[2], x.shape[3])))
        self.resize_to_vit_size = transforms.Resize((224, 224))
        
    def compute_loss(self, output, target, meta_target, meta_structure):
        pred = output[0]
        loss = F.cross_entropy(pred, target)
        return loss

    def forward(self, x, embedding, indicator):
        questions = x[:, :8, :, :]
        answers = x[:, 8:, :, :]

        q = self.reshape_input_batch(questions)
        a = self.reshape_input_batch(answers)

        q = torch.stack([self.resize_to_vit_size(self.grayscale_to_rgb(x)) for x in q])
        a = torch.stack([self.resize_to_vit_size(self.grayscale_to_rgb(x)) for x in a])

        if self.verbose:
            print("Shape of q =", q.shape)
            print("Shape of a =", q.shape)

        q_vit = self.encoder(q)['hidden_states'][-1][:, 0, :]
        a_vit = self.encoder(a)['hidden_states'][-1][:, 0, :]
        
        #q_vit = self.Linear1(q_vit)
        #q_vit = self.Linear2(q_vit)
        
        #a_vit = self.Linear1(a_vit)
        #a_vit = self.Linear2(a_vit)
        
        q_imgs = self.vec2image(q_vit)
        a_imgs = self.vec2image(a_vit)

        logits = self.decoder(q_imgs, a_imgs)

        return logits, None

### Training and Evalution

In [9]:
### Helper functions

def train(epoch, save_file):
    model.train()
    train_loss = 0
    accuracy = 0
    loss_all = 0.0
    acc_all = 0.0
    counter = 0
    for batch_idx, (image, target, meta_target, meta_structure, embedding, indicator) in enumerate(trainloader):
        counter += 1
        if args.cuda:
            image = image.cuda()
            target = target.cuda()
            meta_target = meta_target.cuda()
            meta_structure = meta_structure.cuda()
            embedding = embedding.cuda()
            indicator = indicator.cuda()
        loss, acc = model.train_(image, target, meta_target, meta_structure, embedding, indicator)
        save_str = 'Train: Epoch:{}, Batch:{}, Loss:{:.6f}, Acc:{:.4f}'.format(epoch, batch_idx, loss, acc)
        if counter % 20 == 0:
            print(save_str)
        with open(save_file, 'a') as f:
            f.write(save_str + "\n")
        loss_all += loss
        acc_all += acc
    if counter > 0:
        save_str = "Train_: Avg Training Loss: {:.6f}, Avg Training Acc: {:.6f}".format(
            loss_all/float(counter),
            (acc_all/float(counter))
        )
        print(save_str)
        with open(save_file, 'a') as f:
            f.write(save_str + "\n")
    return loss_all/float(counter), acc_all/float(counter)

def validate(epoch, save_file):
    model.eval()
    val_loss = 0
    accuracy = 0
    loss_all = 0.0
    acc_all = 0.0
    counter = 0
    batch_idx = 0
    for batch_idx, (image, target, meta_target, meta_structure, embedding, indicator) in enumerate(validloader):
        counter += 1
        if args.cuda:
            image = image.cuda()
            target = target.cuda()
            meta_target = meta_target.cuda()
            meta_structure = meta_structure.cuda()
            embedding = embedding.cuda()
            indicator = indicator.cuda()
        loss, acc = model.validate_(image, target, meta_target, meta_structure, embedding, indicator)
        loss_all += loss
        acc_all += acc
    if counter > 0:
        save_str = "Val_: Total Validation Loss: {:.6f}, Acc: {:.4f}".format((loss_all/float(counter)), (acc_all/float(counter)))
        print(save_str)
        with open(save_file, 'a') as f:
            f.write(save_str + "\n")
    return loss_all/float(counter), acc_all/float(counter)

def test(epoch, save_file):
    model.eval()
    accuracy = 0
    acc_all = 0.0
    counter = 0
    for batch_idx, (image, target, meta_target, meta_structure, embedding, indicator) in enumerate(testloader):
        counter += 1
        if args.cuda:
            image = image.cuda()
            target = target.cuda()
            meta_target = meta_target.cuda()
            meta_structure = meta_structure.cuda()
            embedding = embedding.cuda()
            indicator = indicator.cuda()
        acc = model.test_(image, target, meta_target, meta_structure, embedding, indicator)
        acc_all += acc
    if counter > 0:
        save_str = "Test_: Total Testing Acc: {:.4f}".format((acc_all / float(counter)))
        print(save_str)
        with open(save_file, 'a') as f:
            f.write(save_str + "\n")
    return acc_all/float(counter)

In [10]:
%%time
model = ViTSCL(args)
model = model.cuda()
#model = torch.nn.DataParallel(model, device_ids=[1, 2, 3])

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


CPU times: user 48.1 s, sys: 4.09 s, total: 52.2 s
Wall time: 1min 33s


In [12]:
model.load_model("/common/home/ab2253/Desktop/ckpt_res/", "42")

In [None]:
SAVE_FILE = "ViTSCL_take2_ep82" + time.strftime("%Y-%m-%d_%H:%M:%S", time.gmtime()) + "_" + str(args.perc_train)

for epoch in range(0, args.epochs):
    t0 = time.time()
    train(epoch, SAVE_FILE)
    avg_loss, avg_acc = validate(epoch, SAVE_FILE)
    test_acc = test(epoch, SAVE_FILE)
    model.save_model(args.save, epoch, avg_acc, avg_loss)
    print("Time taken = {:.4f} s\n".format(time.time() - t0))

Train: Epoch:0, Batch:19, Loss:0.285534, Acc:84.3750
Train: Epoch:0, Batch:39, Loss:0.165335, Acc:93.7500
Train: Epoch:0, Batch:59, Loss:0.477421, Acc:90.6250
Train: Epoch:0, Batch:79, Loss:0.141237, Acc:100.0000
Train: Epoch:0, Batch:99, Loss:0.313336, Acc:90.6250
Train: Epoch:0, Batch:119, Loss:0.136612, Acc:93.7500
Train: Epoch:0, Batch:139, Loss:0.175871, Acc:90.6250
Train: Epoch:0, Batch:159, Loss:0.227976, Acc:93.7500
Train: Epoch:0, Batch:179, Loss:0.102543, Acc:93.7500
Train: Epoch:0, Batch:199, Loss:0.258211, Acc:87.5000
Train: Epoch:0, Batch:219, Loss:0.303621, Acc:87.5000
Train: Epoch:0, Batch:239, Loss:0.115665, Acc:93.7500
Train: Epoch:0, Batch:259, Loss:0.149667, Acc:93.7500
Train: Epoch:0, Batch:279, Loss:0.310863, Acc:90.6250
Train: Epoch:0, Batch:299, Loss:0.191723, Acc:96.8750
Train: Epoch:0, Batch:319, Loss:0.469431, Acc:84.3750
Train: Epoch:0, Batch:339, Loss:0.210024, Acc:96.8750
Train: Epoch:0, Batch:359, Loss:0.234836, Acc:93.7500
Train: Epoch:0, Batch:379, Loss:

Train: Epoch:3, Batch:659, Loss:0.261186, Acc:84.3750
Train: Epoch:3, Batch:679, Loss:0.288735, Acc:87.5000
Train: Epoch:3, Batch:699, Loss:0.473454, Acc:84.3750
Train: Epoch:3, Batch:719, Loss:0.095166, Acc:96.8750
Train: Epoch:3, Batch:739, Loss:0.293026, Acc:87.5000
Train_: Avg Training Loss: 0.197340, Avg Training Acc: 92.404167
Val_: Total Validation Loss: 1.214170, Acc: 70.9125
Test_: Total Testing Acc: 71.3125
Time taken = 1975.4974 s

Train: Epoch:4, Batch:19, Loss:0.039414, Acc:100.0000
Train: Epoch:4, Batch:39, Loss:0.146688, Acc:96.8750
Train: Epoch:4, Batch:59, Loss:0.366090, Acc:81.2500
Train: Epoch:4, Batch:79, Loss:0.217482, Acc:90.6250
Train: Epoch:4, Batch:99, Loss:0.080096, Acc:96.8750
Train: Epoch:4, Batch:119, Loss:0.154731, Acc:90.6250
Train: Epoch:4, Batch:139, Loss:0.178881, Acc:93.7500
Train: Epoch:4, Batch:159, Loss:0.163252, Acc:90.6250
Train: Epoch:4, Batch:179, Loss:0.190591, Acc:93.7500
Train: Epoch:4, Batch:199, Loss:0.341482, Acc:84.3750
Train: Epoch:4, B

Train: Epoch:7, Batch:479, Loss:0.338328, Acc:90.6250
Train: Epoch:7, Batch:499, Loss:0.172356, Acc:87.5000
Train: Epoch:7, Batch:519, Loss:0.105574, Acc:96.8750
Train: Epoch:7, Batch:539, Loss:0.414000, Acc:81.2500
Train: Epoch:7, Batch:559, Loss:0.124856, Acc:93.7500
Train: Epoch:7, Batch:579, Loss:0.213294, Acc:90.6250
Train: Epoch:7, Batch:599, Loss:0.106667, Acc:96.8750
Train: Epoch:7, Batch:619, Loss:0.208634, Acc:90.6250
Train: Epoch:7, Batch:639, Loss:0.239683, Acc:84.3750
Train: Epoch:7, Batch:659, Loss:0.186554, Acc:87.5000
Train: Epoch:7, Batch:679, Loss:0.089331, Acc:96.8750
Train: Epoch:7, Batch:699, Loss:0.179390, Acc:93.7500
Train: Epoch:7, Batch:719, Loss:0.079456, Acc:96.8750
Train: Epoch:7, Batch:739, Loss:0.053988, Acc:96.8750
Train_: Avg Training Loss: 0.163890, Avg Training Acc: 93.758333
Val_: Total Validation Loss: 1.279206, Acc: 70.6125
Test_: Total Testing Acc: 71.3750
Time taken = 2013.3952 s

Train: Epoch:8, Batch:19, Loss:0.180141, Acc:93.7500
Train: Epoch:8

Train: Epoch:11, Batch:279, Loss:0.091661, Acc:100.0000
Train: Epoch:11, Batch:299, Loss:0.419169, Acc:87.5000
Train: Epoch:11, Batch:319, Loss:0.156915, Acc:96.8750
Train: Epoch:11, Batch:339, Loss:0.074185, Acc:100.0000
Train: Epoch:11, Batch:359, Loss:0.263249, Acc:87.5000
Train: Epoch:11, Batch:379, Loss:0.029427, Acc:100.0000
Train: Epoch:11, Batch:399, Loss:0.133304, Acc:93.7500
Train: Epoch:11, Batch:419, Loss:0.347372, Acc:84.3750
Train: Epoch:11, Batch:439, Loss:0.108162, Acc:96.8750
Train: Epoch:11, Batch:459, Loss:0.296727, Acc:93.7500
Train: Epoch:11, Batch:479, Loss:0.198997, Acc:87.5000
Train: Epoch:11, Batch:499, Loss:0.245256, Acc:87.5000
Train: Epoch:11, Batch:519, Loss:0.229080, Acc:93.7500
Train: Epoch:11, Batch:539, Loss:0.193307, Acc:87.5000
Train: Epoch:11, Batch:559, Loss:0.153056, Acc:93.7500
Train: Epoch:11, Batch:579, Loss:0.223167, Acc:93.7500
Train: Epoch:11, Batch:599, Loss:0.127017, Acc:96.8750
Train: Epoch:11, Batch:619, Loss:0.198336, Acc:90.6250
Train: 

Train: Epoch:15, Batch:59, Loss:0.168957, Acc:90.6250
Train: Epoch:15, Batch:79, Loss:0.250407, Acc:90.6250
Train: Epoch:15, Batch:99, Loss:0.174407, Acc:93.7500
Train: Epoch:15, Batch:119, Loss:0.108108, Acc:96.8750
Train: Epoch:15, Batch:139, Loss:0.160835, Acc:93.7500
Train: Epoch:15, Batch:159, Loss:0.040671, Acc:100.0000
Train: Epoch:15, Batch:179, Loss:0.078708, Acc:100.0000
Train: Epoch:15, Batch:199, Loss:0.067304, Acc:100.0000
Train: Epoch:15, Batch:219, Loss:0.266765, Acc:87.5000
Train: Epoch:15, Batch:239, Loss:0.113065, Acc:96.8750
Train: Epoch:15, Batch:259, Loss:0.241152, Acc:90.6250
Train: Epoch:15, Batch:279, Loss:0.064644, Acc:96.8750
Train: Epoch:15, Batch:299, Loss:0.174059, Acc:90.6250
Train: Epoch:15, Batch:319, Loss:0.207274, Acc:90.6250
Train: Epoch:15, Batch:339, Loss:0.099867, Acc:96.8750
Train: Epoch:15, Batch:359, Loss:0.030084, Acc:100.0000
Train: Epoch:15, Batch:379, Loss:0.067621, Acc:96.8750
Train: Epoch:15, Batch:399, Loss:0.121850, Acc:93.7500
Train: Ep

Train: Epoch:18, Batch:639, Loss:0.332674, Acc:87.5000
Train: Epoch:18, Batch:659, Loss:0.285838, Acc:84.3750
Train: Epoch:18, Batch:679, Loss:0.495244, Acc:75.0000
Train: Epoch:18, Batch:699, Loss:0.292365, Acc:90.6250
Train: Epoch:18, Batch:719, Loss:0.120784, Acc:90.6250
Train: Epoch:18, Batch:739, Loss:0.305413, Acc:90.6250
Train_: Avg Training Loss: 0.150235, Avg Training Acc: 94.433333
Val_: Total Validation Loss: 1.381826, Acc: 71.1500
Test_: Total Testing Acc: 71.0125
Time taken = 1975.9838 s

Train: Epoch:19, Batch:19, Loss:0.097245, Acc:96.8750
Train: Epoch:19, Batch:39, Loss:0.132470, Acc:96.8750
Train: Epoch:19, Batch:59, Loss:0.210798, Acc:90.6250
Train: Epoch:19, Batch:79, Loss:0.261459, Acc:90.6250
Train: Epoch:19, Batch:99, Loss:0.207239, Acc:87.5000
Train: Epoch:19, Batch:119, Loss:0.351924, Acc:90.6250
Train: Epoch:19, Batch:139, Loss:0.112347, Acc:90.6250
Train: Epoch:19, Batch:159, Loss:0.333171, Acc:87.5000
Train: Epoch:19, Batch:179, Loss:0.086099, Acc:96.8750
Tra

Train: Epoch:22, Batch:419, Loss:0.181130, Acc:96.8750
Train: Epoch:22, Batch:439, Loss:0.254747, Acc:90.6250
Train: Epoch:22, Batch:459, Loss:0.454616, Acc:87.5000
Train: Epoch:22, Batch:479, Loss:0.047552, Acc:100.0000
Train: Epoch:22, Batch:499, Loss:0.369806, Acc:87.5000
Train: Epoch:22, Batch:519, Loss:0.168620, Acc:93.7500
Train: Epoch:22, Batch:539, Loss:0.446574, Acc:90.6250
Train: Epoch:22, Batch:559, Loss:0.308873, Acc:93.7500
Train: Epoch:22, Batch:579, Loss:0.136470, Acc:96.8750
Train: Epoch:22, Batch:599, Loss:0.202033, Acc:93.7500
Train: Epoch:22, Batch:619, Loss:0.119431, Acc:96.8750
Train: Epoch:22, Batch:639, Loss:0.081231, Acc:100.0000
Train: Epoch:22, Batch:659, Loss:0.103586, Acc:100.0000
Train: Epoch:22, Batch:679, Loss:0.073002, Acc:96.8750
Train: Epoch:22, Batch:699, Loss:0.061366, Acc:96.8750
Train: Epoch:22, Batch:719, Loss:0.151589, Acc:90.6250
Train: Epoch:22, Batch:739, Loss:0.236522, Acc:93.7500
Train_: Avg Training Loss: 0.161450, Avg Training Acc: 94.2208

Train: Epoch:26, Batch:179, Loss:0.126661, Acc:93.7500
Train: Epoch:26, Batch:199, Loss:0.098232, Acc:96.8750
Train: Epoch:26, Batch:219, Loss:0.081156, Acc:96.8750
Train: Epoch:26, Batch:239, Loss:0.095217, Acc:100.0000
Train: Epoch:26, Batch:259, Loss:0.064225, Acc:96.8750
Train: Epoch:26, Batch:279, Loss:0.014909, Acc:100.0000
Train: Epoch:26, Batch:299, Loss:0.174628, Acc:93.7500
Train: Epoch:26, Batch:319, Loss:0.063786, Acc:100.0000
Train: Epoch:26, Batch:339, Loss:0.100670, Acc:96.8750
Train: Epoch:26, Batch:359, Loss:0.023455, Acc:100.0000
Train: Epoch:26, Batch:379, Loss:0.050989, Acc:100.0000
Train: Epoch:26, Batch:399, Loss:0.117167, Acc:90.6250
Train: Epoch:26, Batch:419, Loss:0.167161, Acc:93.7500
Train: Epoch:26, Batch:439, Loss:0.075763, Acc:96.8750
Train: Epoch:26, Batch:459, Loss:0.168913, Acc:87.5000
Train: Epoch:26, Batch:479, Loss:0.110947, Acc:100.0000
Train: Epoch:26, Batch:499, Loss:0.344967, Acc:93.7500
Train: Epoch:26, Batch:519, Loss:0.173951, Acc:93.7500
Trai

Train_: Avg Training Loss: 0.134692, Avg Training Acc: 95.179167
Val_: Total Validation Loss: 1.520793, Acc: 71.3000
Test_: Total Testing Acc: 69.9125
Time taken = 1954.3317 s

Train: Epoch:30, Batch:19, Loss:0.012916, Acc:100.0000
Train: Epoch:30, Batch:39, Loss:0.155400, Acc:93.7500
Train: Epoch:30, Batch:59, Loss:0.077361, Acc:93.7500
Train: Epoch:30, Batch:79, Loss:0.127768, Acc:93.7500
Train: Epoch:30, Batch:99, Loss:0.210140, Acc:90.6250
Train: Epoch:30, Batch:119, Loss:0.087176, Acc:96.8750
Train: Epoch:30, Batch:139, Loss:0.300023, Acc:90.6250
Train: Epoch:30, Batch:159, Loss:0.097147, Acc:96.8750
Train: Epoch:30, Batch:179, Loss:0.314534, Acc:84.3750
Train: Epoch:30, Batch:199, Loss:0.132021, Acc:90.6250
Train: Epoch:30, Batch:219, Loss:0.112645, Acc:93.7500
Train: Epoch:30, Batch:239, Loss:0.187866, Acc:90.6250
Train: Epoch:30, Batch:259, Loss:0.087659, Acc:96.8750
Train: Epoch:30, Batch:279, Loss:0.109575, Acc:90.6250
Train: Epoch:30, Batch:299, Loss:0.100566, Acc:93.7500
Tr

Train: Epoch:33, Batch:539, Loss:0.075230, Acc:100.0000
Train: Epoch:33, Batch:559, Loss:0.192246, Acc:96.8750
Train: Epoch:33, Batch:579, Loss:0.087482, Acc:96.8750
Train: Epoch:33, Batch:599, Loss:0.101618, Acc:100.0000
Train: Epoch:33, Batch:619, Loss:0.061027, Acc:96.8750
Train: Epoch:33, Batch:639, Loss:0.090388, Acc:90.6250
Train: Epoch:33, Batch:659, Loss:0.187766, Acc:90.6250
Train: Epoch:33, Batch:679, Loss:0.105602, Acc:93.7500
Train: Epoch:33, Batch:699, Loss:0.197296, Acc:93.7500
Train: Epoch:33, Batch:719, Loss:0.076456, Acc:96.8750
Train: Epoch:33, Batch:739, Loss:0.317078, Acc:93.7500
Train_: Avg Training Loss: 0.131107, Avg Training Acc: 95.145833
Val_: Total Validation Loss: 1.514687, Acc: 71.8250
Test_: Total Testing Acc: 71.9875
Time taken = 1957.9116 s

Train: Epoch:34, Batch:19, Loss:0.113933, Acc:96.8750
Train: Epoch:34, Batch:39, Loss:0.065536, Acc:100.0000
Train: Epoch:34, Batch:59, Loss:0.160028, Acc:96.8750
Train: Epoch:34, Batch:79, Loss:0.114250, Acc:96.8750

Train: Epoch:37, Batch:299, Loss:0.069275, Acc:100.0000
Train: Epoch:37, Batch:319, Loss:0.159864, Acc:93.7500
Train: Epoch:37, Batch:339, Loss:0.123273, Acc:93.7500
Train: Epoch:37, Batch:359, Loss:0.157154, Acc:93.7500
Train: Epoch:37, Batch:379, Loss:0.085143, Acc:96.8750
Train: Epoch:37, Batch:399, Loss:0.232337, Acc:96.8750
Train: Epoch:37, Batch:419, Loss:0.154679, Acc:93.7500
Train: Epoch:37, Batch:439, Loss:0.024970, Acc:100.0000
Train: Epoch:37, Batch:459, Loss:0.068161, Acc:96.8750
Train: Epoch:37, Batch:479, Loss:0.080114, Acc:96.8750
Train: Epoch:37, Batch:499, Loss:0.454843, Acc:96.8750
Train: Epoch:37, Batch:519, Loss:0.045693, Acc:100.0000
Train: Epoch:37, Batch:539, Loss:0.172166, Acc:96.8750
Train: Epoch:37, Batch:559, Loss:0.136394, Acc:93.7500
Train: Epoch:37, Batch:579, Loss:0.066865, Acc:100.0000
Train: Epoch:37, Batch:599, Loss:0.363941, Acc:90.6250
Train: Epoch:37, Batch:619, Loss:0.230858, Acc:93.7500
Train: Epoch:37, Batch:639, Loss:0.106999, Acc:93.7500
Train:

Train: Epoch:41, Batch:59, Loss:0.066985, Acc:96.8750
Train: Epoch:41, Batch:79, Loss:0.021774, Acc:100.0000
Train: Epoch:41, Batch:99, Loss:0.248450, Acc:90.6250
Train: Epoch:41, Batch:119, Loss:0.094590, Acc:96.8750
Train: Epoch:41, Batch:139, Loss:0.116728, Acc:96.8750
Train: Epoch:41, Batch:159, Loss:0.192431, Acc:93.7500
Train: Epoch:41, Batch:179, Loss:0.205406, Acc:90.6250
Train: Epoch:41, Batch:199, Loss:0.028106, Acc:100.0000
Train: Epoch:41, Batch:219, Loss:0.145676, Acc:96.8750
Train: Epoch:41, Batch:239, Loss:0.099538, Acc:96.8750
Train: Epoch:41, Batch:259, Loss:0.091636, Acc:93.7500
Train: Epoch:41, Batch:279, Loss:0.120847, Acc:96.8750
Train: Epoch:41, Batch:299, Loss:0.006287, Acc:100.0000
Train: Epoch:41, Batch:319, Loss:0.139299, Acc:96.8750
Train: Epoch:41, Batch:339, Loss:0.025027, Acc:100.0000
Train: Epoch:41, Batch:359, Loss:0.068495, Acc:100.0000
Train: Epoch:41, Batch:379, Loss:0.198013, Acc:93.7500
Train: Epoch:41, Batch:399, Loss:0.008581, Acc:100.0000
Train: 

In [13]:
testset = dataset(args.path, "test", args.img_size, transform=transforms.Compose([ToTensor()]))

In [14]:
center_single_fnames = [x for x in testset.file_names if '/center_single/' in x]
distribute_four_fnames = [x for x in testset.file_names if '/distribute_four/' in x]
in_distribute_four_out_center_single_fnames = \
    [x for x in testset.file_names if '/in_distribute_four_out_center_single/' in x]
left_center_single_right_center_single_fnames = \
    [x for x in testset.file_names if '/left_center_single_right_center_single/' in x]

In [15]:
cs_set = dataset(args.path, "test", args.img_size, transform=transforms.Compose([ToTensor()]))
cs_set.file_names = center_single_fnames
cs_testloader = DataLoader(cs_set, batch_size=args.batch_size, shuffle=False, num_workers=80)

df_set = dataset(args.path, "test", args.img_size, transform=transforms.Compose([ToTensor()]))
df_set.file_names = distribute_four_fnames
df_testloader = DataLoader(df_set, batch_size=args.batch_size, shuffle=False, num_workers=80)

idfo_set = dataset(args.path, "test", args.img_size, transform=transforms.Compose([ToTensor()]))
idfo_set.file_names = in_distribute_four_out_center_single_fnames
idfo_testloader = DataLoader(idfo_set, batch_size=args.batch_size, shuffle=False, num_workers=80)

lcsr_set = dataset(args.path, "test", args.img_size, transform=transforms.Compose([ToTensor()]))
lcsr_set.file_names = left_center_single_right_center_single_fnames
lcsr_testloader = DataLoader(lcsr_set, batch_size=args.batch_size, shuffle=False, num_workers=80)

In [16]:
def test_sep(epoch, tl, save_file):
    model.eval()
    accuracy = 0

    acc_all = 0.0
    counter = 0
    for batch_idx, (image, target, meta_target, meta_structure, embedding, indicator) in enumerate(tl):
        counter += 1
        if args.cuda:
            image = image.cuda()
            target = target.cuda()
            meta_target = meta_target.cuda()
            meta_structure = meta_structure.cuda()
            embedding = embedding.cuda()
            indicator = indicator.cuda()
        acc = model.test_(image, target, meta_target, meta_structure, embedding, indicator)
        # print('Test: Epoch:{}, Batch:{}, Acc:{:.4f}.'.format(epoch, batch_idx, acc))  
        acc_all += acc
    if counter > 0:
        save_str = "Test_: Total Testing Acc: {:.4f}".format(acc_all / float(counter))
        print(save_str)
        with open(save_file, 'a') as f:
            f.write(save_str + "\n")
    return acc_all/float(counter)

In [19]:
SAVE_FILE_SEP = SAVE_FILE + "_sep"

test_acc_cs = test_sep(epoch, cs_testloader, SAVE_FILE_SEP)
test_acc_df = test_sep(epoch, df_testloader, SAVE_FILE_SEP)
test_acc_idfo = test_sep(epoch, idfo_testloader, SAVE_FILE_SEP)
test_acc_lcsr = test_sep(epoch, lcsr_testloader, SAVE_FILE_SEP)

Test_: Total Testing Acc: 93.6012
Test_: Total Testing Acc: 72.7679
Test_: Total Testing Acc: 66.7163
Test_: Total Testing Acc: 53.0258
