In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils as utils
import torch.utils.data as data
import torchvision.models as models
import torchvision.utils as v_utils
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.offsetbox import OffsetImage, AnnotationBbox, TextArea
from matplotlib.cbook import get_sample_data
from PIL import ImageFile
import os

In [None]:
# Let PIL decode partially-written / truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Dataset configuration.
image_size = 256            # images are resized, then center-cropped, to this size
PATH = "./cropWebtoonImg/"  # ImageFolder root directory (one sub-folder per class)

In [None]:
# Load every image under PATH; ImageFolder uses each sub-directory name as the
# class label. NOTE: `data` shadows the `torch.utils.data as data` alias from
# the import cell; later cells rely on this name, so it is kept.
data = datasets.ImageFolder(PATH, transform=transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
]))

class_to_idx = data.class_to_idx
print(class_to_idx)

# Invert the mapping by value. Enumerating the dict would pair each class name
# with its *position* in the dict, which only equals its index when insertion
# order happens to match index order.
idx_to_class = {idx: name for name, idx in class_to_idx.items()}
print(idx_to_class)

# File path of every sample, in dataset order (row i of later results <-> img_list[i]).
img_list = [path for path, _ in data.imgs]

# Optional: one thumbnail per class (sorted so order matches class order),
# originally read from Google Drive on Colab.
#img_list2 = []
#for img in os.listdir('/content/drive/My Drive/dataset/thumnail'):
#    img_list2.append(os.path.join('/content/drive/My Drive/dataset/thumnail',img))
#img_list2.sort()

In [None]:
# Pretrained backbone whose early stages provide the style features.
# Kept under its own name: a later cell assigns `resnet = Resnet()`, and if the
# backbone shared that name, re-running that cell would make Resnet() slice the
# wrapper's own children and silently produce empty (identity) stages.
backbone_net = models.resnet50(pretrained=True)


class Resnet(nn.Module):
    """Expose the intermediate activations of the first stages of a ResNet.

    Args:
        backbone: network whose ``children()`` are sliced into four stages.
            Defaults to the module-level ``backbone_net``. Assumes the standard
            torchvision ResNet child order (conv1, bn1, relu, maxpool, layer1,
            layer2, ...) -- TODO confirm if another backbone is passed.

    ``forward(x)`` returns the 4-tuple ``(out_0, out_1, out_2, out_3)``. Under
    that assumed order, these are the outputs after children[0:1], [1:4], [4:5]
    and [5:6] respectively, each stage feeding the next.
    """

    def __init__(self, backbone=None):
        super(Resnet, self).__init__()
        if backbone is None:
            backbone = backbone_net
        stages = list(backbone.children())
        self.layer0 = nn.Sequential(*stages[0:1])
        self.layer1 = nn.Sequential(*stages[1:4])
        self.layer2 = nn.Sequential(*stages[4:5])
        self.layer3 = nn.Sequential(*stages[5:6])
        # Deeper stages (slices [6:7] and [7:8]) are intentionally not used.

    def forward(self, x):
        out_0 = self.layer0(x)
        out_1 = self.layer1(out_0)
        out_2 = self.layer2(out_1)
        out_3 = self.layer3(out_2)
        return out_0, out_1, out_2, out_3

In [None]:
class GramMatrix(nn.Module):
    """Batched Gram matrix of a 4-D feature map.

    For ``input`` of shape (b, c, h, w), returns a (b, c, c) tensor whose
    [k, i, j] entry is the dot product of channels i and j of sample k, summed
    over all h*w spatial positions (not normalized by h*w).
    """

    def forward(self, input):
        b, c, h, w = input.size()
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # channels_last tensors, are accepted as well.
        features = input.reshape(b, c, h * w)
        return torch.bmm(features, features.transpose(1, 2))

class GramMSELoss(nn.Module):
    """Mean-squared error between the Gram matrix of ``input`` and ``target``.

    ``target`` is presumably a precomputed Gram matrix with the same shape as
    ``GramMatrix()(input)`` -- confirm at call sites.
    """

    def forward(self, input, target):
        criterion = nn.MSELoss()
        input_gram = GramMatrix()(input)
        return criterion(input_gram, target)

In [None]:
# Frozen feature extractor on the GPU. eval() makes any BatchNorm layers use
# their stored running statistics instead of per-image batch statistics, and
# stops them from updating those statistics while features are extracted.
resnet = Resnet().cuda().eval()
for param in resnet.parameters():
    param.requires_grad = False

In [None]:
# Style descriptor per image: the Gram matrices of every Resnet stage output,
# flattened and concatenated into one 1-D vector. Gram entries are unnormalized
# sums over spatial positions, so stages with larger h*w give larger values.
total_arr = []   # one 1-D numpy array per image, in dataset order
label_arr = []   # matching class index for each image

gram_matrix = GramMatrix()
# Send inputs to whatever device the model's weights live on (it was moved with .cuda()).
device = next(resnet.parameters()).device

with torch.no_grad():
    for idx, (image, label) in enumerate(data):
        batch = image.unsqueeze(0).to(device)  # (C, H, W) -> (1, C, H, W)
        grams = [gram_matrix(feat) for feat in resnet(batch)]
        descriptor = torch.cat([g.reshape(-1) for g in grams])

        total_arr.append(descriptor.cpu().numpy())
        label_arr.append(label)

        done = idx + 1  # number of images processed so far
        if done % 50 == 0:
            print(f'{done} images style feature extracted..[{done / len(data):.0%}]')
print('Image style feature extraction done.')

In [None]:
# Embed the per-image style descriptors in 2-D with t-SNE (PCA initialisation,
# fixed random_state). `result` holds one (x, y) row per image, in dataset order.
model = TSNE(
    n_components=2,
    perplexity=100,
    init='pca',
    random_state=0,
    verbose=3,
)
result = model.fit_transform(np.asarray(total_arr))

In [None]:
# Sanity check: mean t-SNE position of the first class vs. the remaining samples.
# Columns must be indexed as result[rows, k]; result[rows][0] would pick the
# first *row* (a single point) instead of the x column.
labels = np.asarray(label_arr)
# Assumes samples are grouped by class (ImageFolder lists them class by class),
# so the first label that differs from sample 0 marks the start of the second group.
assert (labels != labels[0]).any(), 'need at least two classes to compare'
split = int(np.argmax(labels != labels[0]))

first_group, rest = result[:split], result[split:]
print(first_group[:, 0].mean(), first_group[:, 1].mean())
print(label_arr[split - 1])
print(rest[:, 0].mean(), rest[:, 1].mean())
print(label_arr[split])

In [None]:
def _make_text_area(text):
    """Build a TextArea for ``text`` across matplotlib versions.

    ``minimumdescent`` was deprecated in matplotlib 3.4 and later removed, so
    fall back to the plain constructor when it is rejected.
    """
    try:
        return TextArea(text, minimumdescent=False)
    except TypeError:
        return TextArea(text)


def imscatter(x, y, image, ax=None, zoom=1, show_by_thumnail=False, title='webtoon'):
    """Draw ``image`` as a marker at each (x, y) point.

    Args:
        x, y: scalar coordinates or equal-length sequences of coordinates.
        image: anything ``plt.imread`` accepts (e.g. a file path), or an
            already-loaded numpy image array.
        ax: axes to draw on; defaults to the current axes.
        zoom: scale factor applied to the image.
        show_by_thumnail: if True, also draw ``title`` in a box offset
            (20, -40) points from each point.
        title: label text used when ``show_by_thumnail`` is True.

    Returns:
        List of the artists added to ``ax``.
    """
    if ax is None:
        ax = plt.gca()
    # Only non-array inputs need loading; modern plt.imread does not raise
    # TypeError on arrays, so an explicit check is used instead.
    if not isinstance(image, np.ndarray):
        image = plt.imread(image)
    im = OffsetImage(image, zoom=zoom)

    # Convert inputs to arrays with at least one dimension.
    x, y = np.atleast_1d(x, y)

    artists = []
    for x0, y0 in zip(x, y):
        if show_by_thumnail:
            label_box = AnnotationBbox(_make_text_area(title), (x0, y0),
                                       xybox=(20, -40),
                                       xycoords='data',
                                       boxcoords="offset points")
            artists.append(ax.add_artist(label_box))
        image_box = AnnotationBbox(im, (x0, y0), xycoords='data', frameon=False)
        artists.append(ax.add_artist(image_box))

    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()
    return artists

In [None]:
# Scatter every image at its 2-D t-SNE position, drawn as a small thumbnail.
fig, ax = plt.subplots(figsize=(20, 12))
for i, (tsne_x, tsne_y) in enumerate(result):
    imscatter(tsne_x, tsne_y, image=img_list[i], ax=ax, zoom=0.2)
plt.show()

In [None]:
# Mean t-SNE position per class: avg_list[k] is the (x, y) centroid of the
# k-th distinct label in sorted order.
scatter_x, scatter_y = result[:, 0], result[:, 1]
group = np.array(label_arr)

avg_list = [
    (np.mean(scatter_x[group == g]), np.mean(scatter_y[group == g]))
    for g in np.unique(group)
]

In [None]:
# Plot one thumbnail per class at that class's mean t-SNE position, labelled
# with the class name. Builds its own thumbnail list so it runs on a fresh kernel.
THUMBNAIL_DIR = '/content/drive/My Drive/dataset/thumnail'  # TODO: make configurable

labels = np.asarray(label_arr)
class_ids = np.unique(labels)  # same order in which avg_list was built

if os.path.isdir(THUMBNAIL_DIR):
    # assumes one thumbnail per class whose sorted file order matches class order
    thumbnail_files = sorted(os.listdir(THUMBNAIL_DIR))
    thumbnails = {int(c): os.path.join(THUMBNAIL_DIR, thumbnail_files[int(c)])
                  for c in class_ids}
else:
    # No thumbnail folder available: use the first dataset image of each class.
    thumbnails = {int(c): img_list[int(np.argmax(labels == c))] for c in class_ids}

fig, ax = plt.subplots(figsize=(20, 12))
for class_id, (x_avg, y_avg) in zip(class_ids, avg_list):
    imscatter(x_avg, y_avg, image=thumbnails[int(class_id)], ax=ax, zoom=0.6,
              show_by_thumnail=True, title=idx_to_class[int(class_id)])
plt.show()