image classification에 word embedding을 사용하면 어떻게 될까?

-> CIFAR100 데이터의 레이블을 embedding된 vector로 바꾼 후에 그러한 vector를 예측하도록 regression 한 다음에, classification 한다.

image -> embedding -> label 의 과정을 거쳐 classification 한다.

In [268]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("using", device)

torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed(777)

using cuda


In [269]:
batch_size = 64
vector_size = 50
epoch = 500

test_len = 10
test_num = 50

lr = 0.002

In [270]:
#loading glove
embed_dict = {}
with open("./glove6B/glove6B50d.txt", 'r', encoding="utf-8") as f:
    for line in f:
        values = line.split()
        word = values[0]
        vector = numpy.asarray(values[1:], "float32")
        embed_dict[word] = vector

In [271]:
#labels of cifar100, some words that are not in glove are replaced to None
cifar_classnum_to_label=["apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle", "bicycle", "bottle", "bowl", "boy", "bridge", "bus", "butterfly", "camel", "can", "castle", "caterpillar", "cattle", "chair", "chimpanzee", "clock", "cloud", "cockroach", "couch", "cra", "crocodile", "cup", "dinosaur", "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house", "kangaroo", "keyboard", "lamp", "lawn_mower", "leopard", "lion", "lizard", "lobster", "man", "maple_tree", "motorcycle", "mountain", "mouse", "mushroom", "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear", "pickup_truck", "pine_tree", "plain", "plate", "poppy", "porcupine", "possum", "rabbit", "raccoon", "ray", "road", "rocket", "rose", "sea", "seal", "shark", "shrew", "skunk", "skyscraper", "snail", "snake", "spider", "squirrel", "streetcar", "sunflower", "sweet_pepper", "table", "tank", "telephone", "television", "tiger", "tractor", "train", "trout", "tulip", "turtle", "wardrobe", "whale", "willow_tree", "wolf", "woman", "worm"]

In [272]:
cifar_classnum_to_embed = []
for label in cifar_classnum_to_label:
    if(label in embed_dict):
        cifar_classnum_to_embed.append(embed_dict[label])
    else:
        words = label.split('_')
        embed = embed_dict[words[0]] + embed_dict[words[1]]
        cifar_classnum_to_embed.append(embed)

cifar_classnum_to_embed = torch.FloatTensor(cifar_classnum_to_embed).to(device)

In [273]:
cifar_train = torchvision.datasets.CIFAR100(
    root = '../CIFAR100_data',
    train = True, 
    transform = torchvision.transforms.ToTensor(), 
    download = True
)

cifar_test = torchvision.datasets.CIFAR100(
    root = '../CIFAR100_data',
    train = False, 
    transform = torchvision.transforms.ToTensor(), 
    download = True
)

Files already downloaded and verified
Files already downloaded and verified


In [274]:
data_loader_train = torch.utils.data.DataLoader(cifar_train, batch_size = batch_size, shuffle = True)
data_loader_test = torch.utils.data.DataLoader(cifar_test, batch_size = batch_size, shuffle = True)

In [275]:
class image_to_embed(torch.nn.Module):
    def __init__(self, in_size = 3, channels = [16, 16, 32, 32, 64, 64, 128, 128], strides = [1, 1, 2, 1, 2, 1, 2, 1], half_size = False, linear_in = 128 * 4 * 4, linear_hidden = 128, linear_num = 2, linear_out = vector_size):
        super(image_to_embed, self).__init__()
        self.linear_in = linear_in
        block_list = []

        for i in range(len(channels)):
            out_size = channels[i]
            block_list.append(resblock(in_size, in_size, out_size, strides[i]))
            in_size = out_size

        self.blocks = torch.nn.Sequential(*block_list)
        self.relu = torch.nn.ReLU()

        linear_list = []
        linear_list.append(torch.nn.Linear(linear_in, linear_hidden))
        for i in range(linear_num):
            linear_list.append(torch.nn.Linear(linear_hidden, linear_hidden))
            linear_list.append(self.relu)
        
        linear_list.append(torch.nn.Linear(linear_hidden, linear_out))

        self.linears = torch.nn.Sequential(*linear_list)



    def forward(self, x):
        out = self.blocks(x)
        out = out.view(-1, self.linear_in)
        out = self.linears(out)
        return out
    

class resblock(torch.nn.Module): # in, hidden_channel은 말 그대로 channel 수, half_size는 output의 크기를 input 그대로 유지할 것인지, stride를 2로 할 것인지 결정
    def __init__(self, in_channel, hidden_channel, out_channel, stride = 1):
        super(resblock, self).__init__()
        self.in_channel = in_channel
        self.hidden_channel = hidden_channel
        self.out_channel = out_channel
        self.downsample = None

        if(stride != 1 or in_channel != out_channel):
            self.downsample = torch.nn.Conv2d(in_channel, out_channel, kernel_size=1, stride = stride)
        self.bnskip = torch.nn.BatchNorm2d(out_channel)
        

        self.conv1x1_in = torch.nn.Conv2d(in_channel, hidden_channel, kernel_size=1)
        self.bn1 = torch.nn.BatchNorm2d(hidden_channel)

        self.conv3x3=torch.nn.Conv2d(hidden_channel, hidden_channel, kernel_size=3, padding=1, stride = stride)
        self.bn2 = torch.nn.BatchNorm2d(hidden_channel)

        self.conv1x1_out = torch.nn.Conv2d(hidden_channel, out_channel, kernel_size=1)
        self.bn3 = torch.nn.BatchNorm2d(out_channel)

        self.relu = torch.nn.ReLU(inplace=True)
    
    def forward(self, x):
        out = self.conv1x1_in(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv3x3(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv1x1_out(out)
        out = self.bn3(out)

        if(self.downsample != None):
            identity = self.downsample(x)
        else:
            identity = x
        
        out += identity
        out = self.relu(out)

        return out

In [276]:
model = image_to_embed().to(device)
optim = torch.optim.Adam(model.parameters(), lr = lr)
criterion = torch.nn.MSELoss()

In [280]:
num = len(data_loader_train)
for i in range(epoch):
    avgloss = 0
    for datax, datay in data_loader_train:
        datax = datax.to(device)

        datay = cifar_classnum_to_embed[datay]
        predict = model(datax)

        loss = criterion(predict, datay)
        optim.zero_grad()
        loss.backward()
        optim.step()

        avgloss += loss.item()

    avgloss /= num
    print("EPOCH", i, "LOSS :", avgloss)

        

EPOCH 0 LOSS : 0.3948608104835081


KeyboardInterrupt: 