In [1]:
'''
'''
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets
import torch.optim as optim
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init

from torch.utils.data import DataLoader, Dataset

cuda = torch.cuda.is_available()
cuda

True

## Data Loading

In [2]:
class PDataset(Dataset):
    def __init__(self, data_file_list, label_file_list):
        self.data_file_list = data_file_list
        self.label_file_list = label_file_list
        
        self.data1 = np.load(data_file_list[0])
        self.data2 = np.load(data_file_list[1])
        self.data3 = np.load(data_file_list[2])
        self.label = np.load(label_file_list[0])
        
        #self.label2 = np.load(label_file_list[1])
        
        #self.label3 = np.load(label_file_list[2])
        
        #self.n_class = len(np.unique(self.label))
    
    def __len__(self):
        return self.data1.shape[0]
    
    def __getitem__(self, index):
        stack1 = np.hstack((self.data1[index], self.data2[index]))
        embedding = np.hstack((stack1, self.data3[index]))
        embedding = torch.from_numpy(embedding)
        
        label = self.label[index]
        #label = torch.Tensor(label)
        
        return embedding, label



In [3]:
trainset = PDataset(["shufflenet_embedding.npy", "mobilenet_embedding.npy", "resnet_embedding.npy"]
                     ,["shufflenet_embedding_label.npy","mobilenet_embedding_label.npy","resnet_embedding_label.npy"])
train_dataloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=3,drop_last=False)

## Embedding Classification Network

In [32]:
class Simple(nn.Module):
    def __init__(self, n_in = 512*3, n_out = 2300):
        super(Simple, self).__init__()
        #self.conv1 = nn.Conv1d(in_channels=1, out_channels=256, kernel_size=8, stride=1)
        
        self.fc1 = nn.Linear(n_in, n_in*2, bias = True)
        #self.fc1 = nn.Linear(256, n_in, bias = True)
        self.dropout1 = nn.Dropout(p=0.1)
        self.relu1 = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(n_in*2)
        #self.bn1 = nn.BatchNorm1d(n_in)
        
        self.fc2 = nn.Linear(n_in*2, n_in*4, bias = True)
        #self.fc2 = nn.Linear(n_in, n_out, bias = True)
        self.dropout2 = nn.Dropout(p=0.2)
        self.relu2 = nn.ReLU()
        self.bn2 = nn.BatchNorm1d(n_in*4)

        self.fc0 = nn.Linear(n_in*4, n_in*4, bias = True)
        self.relu0 = nn.ReLU()
        self.bn0 = nn.BatchNorm1d(n_in*4)
        
        self.fc3 = nn.Linear(n_in*4, n_out, bias = True)
        #self.logprob = nn.LogSoftmax(dim=1)
        
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
  
    def forward(self, x):
        #x = x.unsqueeze(1)
        #x = self.conv1(x)  #(N, C, K)
        #x = x.mean([2])
        x = self.fc1(x)
        x = self.dropout1(x)
        x = self.relu1(x)
        x = self.bn1(x)
        
        x = self.fc2(x)
        x = self.dropout2(x)
        x = self.relu2(x)
        x = self.bn2(x)
        
        x = self.fc0(x)
        x = self.relu0(x)
        x = self.bn0(x)
        
        x = self.fc3(x)
        #x = self.logprob(x)
        
        return x

In [33]:
network = Simple()
print(network)

Simple(
  (fc1): Linear(in_features=1536, out_features=3072, bias=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (relu1): ReLU()
  (bn1): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=3072, out_features=6144, bias=True)
  (dropout2): Dropout(p=0.2, inplace=False)
  (relu2): ReLU()
  (bn2): BatchNorm1d(6144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc0): Linear(in_features=6144, out_features=6144, bias=True)
  (relu0): ReLU()
  (bn0): BatchNorm1d(6144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=6144, out_features=2300, bias=True)
)


In [34]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(network.parameters(), lr = 5e-3)
#optimizer = torch.optim.Adam(network.parameters())

In [35]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
network.to(device)
network.train()
for i in range(10):
    print("this is "+str(i)+"th iteration")
    for epoch in range(20):
        avg_loss = 0
        accuracy = 0
        total = 0
        for batch_num, (data, label) in enumerate(train_dataloader):

            #data = Variable(data.float(), requires_grad = False)
            #label = Variable(label.long(), requires_grad = False)
            #print(label.size())
            data, label = data.to(device), label.to(device)
            optimizer.zero_grad()
            pred = network(data.float())

            loss = criterion(pred, label.view(-1).long())

            loss.backward()

            optimizer.step()

            avg_loss += loss.item()
            if batch_num%1000 == 999:
                print('Epoch: {}\tBatch: {}\tAvg-Loss: {:.4f}'.format(epoch+1, batch_num+1, avg_loss/1000))
                avg_loss = 0.0
                
            _, pred_labels = torch.max(F.softmax(pred, dim=1), 1)
            pred_labels = pred_labels.view(-1)
            
            accuracy += torch.sum(torch.eq(pred_labels, label)).item()
            total += len(label)
            
            del data
            del label
            del loss
        print(accuracy / total)
        

        

this is 0th iteration
Epoch: 1	Batch: 1000	Avg-Loss: 5.7580
Epoch: 1	Batch: 2000	Avg-Loss: 3.4206
Epoch: 1	Batch: 3000	Avg-Loss: 2.5051
Epoch: 1	Batch: 4000	Avg-Loss: 1.9984
Epoch: 1	Batch: 5000	Avg-Loss: 1.6915
Epoch: 1	Batch: 6000	Avg-Loss: 1.4922
0.5564261683334266
Epoch: 2	Batch: 1000	Avg-Loss: 1.2084
Epoch: 2	Batch: 2000	Avg-Loss: 1.1260
Epoch: 2	Batch: 3000	Avg-Loss: 1.0560
Epoch: 2	Batch: 4000	Avg-Loss: 1.0030
Epoch: 2	Batch: 5000	Avg-Loss: 0.9545
Epoch: 2	Batch: 6000	Avg-Loss: 0.9085
0.7993161865052046
Epoch: 3	Batch: 1000	Avg-Loss: 0.7844
Epoch: 3	Batch: 2000	Avg-Loss: 0.7670
Epoch: 3	Batch: 3000	Avg-Loss: 0.7568
Epoch: 3	Batch: 4000	Avg-Loss: 0.7358
Epoch: 3	Batch: 5000	Avg-Loss: 0.7146
Epoch: 3	Batch: 6000	Avg-Loss: 0.6996
0.8506228760061983
Epoch: 4	Batch: 1000	Avg-Loss: 0.6112
Epoch: 4	Batch: 2000	Avg-Loss: 0.6050
Epoch: 4	Batch: 3000	Avg-Loss: 0.5969
Epoch: 4	Batch: 4000	Avg-Loss: 0.5898
Epoch: 4	Batch: 5000	Avg-Loss: 0.5870
Epoch: 4	Batch: 6000	Avg-Loss: 0.5772
0.8791272

KeyboardInterrupt: 

In [36]:
torch.save(network.state_dict(),'embedding_MLP4.pth')

In [9]:
network.parameters

<bound method Module.parameters of Simple(
  (fc1): Linear(in_features=1536, out_features=3072, bias=True)
  (relu1): ReLU()
  (bn1): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=3072, out_features=6144, bias=True)
  (relu2): ReLU()
  (bn2): BatchNorm1d(6144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=6144, out_features=2300, bias=True)
)>

In [10]:
print(network.modules)

<bound method Module.modules of Simple(
  (fc1): Linear(in_features=1536, out_features=3072, bias=True)
  (relu1): ReLU()
  (bn1): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=3072, out_features=6144, bias=True)
  (relu2): ReLU()
  (bn2): BatchNorm1d(6144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=6144, out_features=2300, bias=True)
)>


In [10]:
network.load_state_dict(torch.load('embedding_MLP3.pth'))
optimizer = torch.optim.SGD(network.parameters(), lr = 1e-3)

In [11]:
network.to(device)
network.train()
for i in range(10):
    print("this is "+str(i)+"th iteration")
    for epoch in range(20):
        avg_loss = 0
        accuracy = 0
        total = 0
        for batch_num, (data, label) in enumerate(train_dataloader):

            #data = Variable(data.float(), requires_grad = False)
            #label = Variable(label.long(), requires_grad = False)
            #print(label.size())
            data, label = data.to(device), label.to(device)
            optimizer.zero_grad()
            pred = network(data.float())

            loss = criterion(pred, label.view(-1).long())

            loss.backward()

            optimizer.step()

            avg_loss += loss.item()
            if batch_num%1000 == 999:
                print('Epoch: {}\tBatch: {}\tAvg-Loss: {:.4f}'.format(epoch+1, batch_num+1, avg_loss/1000))
                avg_loss = 0.0
                
            _, pred_labels = torch.max(F.softmax(pred, dim=1), 1)
            pred_labels = pred_labels.view(-1)
            
            accuracy += torch.sum(torch.eq(pred_labels, label)).item()
            total += len(label)
            
            del data
            del label
            del loss
        print(accuracy / total)

this is 0th iteration
Epoch: 1	Batch: 1000	Avg-Loss: 0.2351
Epoch: 1	Batch: 2000	Avg-Loss: 0.2305
Epoch: 1	Batch: 3000	Avg-Loss: 0.2288
Epoch: 1	Batch: 4000	Avg-Loss: 0.2272
Epoch: 1	Batch: 5000	Avg-Loss: 0.2264
Epoch: 1	Batch: 6000	Avg-Loss: 0.2276
0.9498208364856219
Epoch: 2	Batch: 1000	Avg-Loss: 0.2156
Epoch: 2	Batch: 2000	Avg-Loss: 0.2210
Epoch: 2	Batch: 3000	Avg-Loss: 0.2214
Epoch: 2	Batch: 4000	Avg-Loss: 0.2191
Epoch: 2	Batch: 5000	Avg-Loss: 0.2213
Epoch: 2	Batch: 6000	Avg-Loss: 0.2210
0.952022370504796
Epoch: 3	Batch: 1000	Avg-Loss: 0.2138
Epoch: 3	Batch: 2000	Avg-Loss: 0.2106
Epoch: 3	Batch: 3000	Avg-Loss: 0.2150
Epoch: 3	Batch: 4000	Avg-Loss: 0.2156
Epoch: 3	Batch: 5000	Avg-Loss: 0.2153
Epoch: 3	Batch: 6000	Avg-Loss: 0.2133
0.9537081860575999
Epoch: 4	Batch: 1000	Avg-Loss: 0.2080
Epoch: 4	Batch: 2000	Avg-Loss: 0.2136
Epoch: 4	Batch: 3000	Avg-Loss: 0.2115
Epoch: 4	Batch: 4000	Avg-Loss: 0.2080
Epoch: 4	Batch: 5000	Avg-Loss: 0.2105
Epoch: 4	Batch: 6000	Avg-Loss: 0.2120
0.95449879

KeyboardInterrupt: 