In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.transforms import functional as F

In [None]:
class ConvNet(nn.Module):
    """Small CNN that maps an RGB image to a flat vector of class scores.

    Three conv blocks (Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d) feed a
    two-layer MLP head (Linear -> BatchNorm1d -> ReLU -> Linear).

    The 3x3 / stride-1 / padding-1 convolutions keep the spatial size, each
    pool halves it (floor), and the last pool's ``padding=1`` adds one more
    row/column. For the default 64x64 input this gives 64 -> 32 -> 16 -> 9,
    so the head's first Linear layer sees 128 * 9 * 9 = 10368 features.

    Args:
        num_classes: width of the output score vector (default 22, the value
            the original hard-coded head used).
        input_size: side length of the square input images (default 64). It
            only determines the in-features of the first Linear layer.
    """

    def __init__(self, num_classes=22, input_size=64):
        super().__init__()
        # The attribute names (layer1 / layer2) and the order of modules inside
        # each Sequential define the state_dict keys ('layer1.0.weight', ...).
        # Keep them unchanged, or previously saved checkpoints will not load.
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
        )
        # Two plain pools floor-halve the side; the padded last pool gives
        # floor(s / 2) + 1. For input_size=64: 64 // 8 + 1 = 9.
        side = input_size // 2 // 2 // 2 + 1
        self.layer2 = nn.Sequential(
            nn.Linear(128 * side * side, 625),  # 128 * 9 * 9 = 10368 for 64x64 input
            nn.BatchNorm1d(625),
            nn.ReLU(),
            nn.Linear(625, num_classes),
        )

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: float tensor of shape (batch, 3, input_size, input_size). The
                3 channels are fixed by the first Conv2d.

        Returns:
            Tensor of shape (batch, num_classes) with raw scores. No softmax is
            applied.
        """
        out = self.layer1(x)
        # Flatten everything except the batch dimension for the MLP head.
        out = torch.flatten(out, start_dim=1)
        return self.layer2(out)

In [None]:
# Build the classifier and load its trained weights.
model = ConvNet()
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine. The model therefore lives on the CPU here; move it to the
# target device explicitly before running inference.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Recent PyTorch versions accept weights_only=True to restrict what is loaded.
model.load_state_dict(torch.load('./state_dicts/classifier.w', map_location='cpu'))
# Inference mode: BatchNorm layers use their running statistics.
model.eval()

In [None]:
# Side length (pixels) that every input image is resized to. Keep this in step
# with the classifier: its first Linear layer expects 10368 = 128 * 9 * 9
# features, which the conv stack produces for a 64x64 input.
pic_width: int = 64

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Class names for the two prediction groups, indexed by argmax position.
# The classifier's outputs are split eye-colours-first, then hair colours.
# 10 eye + 12 hair labels = 22 outputs.
eye_labels = [
    'aqua', 'black', 'blue', 'brown', 'green',
    'orange', 'pink', 'purple', 'red', 'yellow',
]

hair_labels = [
    'aqua', 'black', 'blonde', 'blue', 'brown', 'gray',
    'green', 'orange', 'pink', 'purple', 'red', 'white',
]

In [None]:
def color_transform(x, saturation=2.5, gamma=0.7, contrast=1.2):
    """Enhance an image's colours before it is classified.

    Applies, in this order, a saturation, a gamma and a contrast adjustment
    using ``torchvision.transforms.functional``. The defaults are the fixed
    factors this notebook originally hard-coded. NOTE(review): presumably they
    make eye/hair colours more distinct for the classifier — confirm.

    Args:
        x: image accepted by the torchvision functional ops. This notebook
            passes a PIL image.
        saturation: saturation factor passed to ``adjust_saturation``.
        gamma: gamma exponent passed to ``adjust_gamma``.
        contrast: contrast factor passed to ``adjust_contrast``.

    Returns:
        The adjusted image.
    """
    x = F.adjust_saturation(x, saturation)
    x = F.adjust_gamma(x, gamma)
    x = F.adjust_contrast(x, contrast)
    return x

# Preprocessing pipeline for the classifier:
#   resize to pic_width x pic_width -> tensor in [0, 1] -> scale to [-1, 1]
#   (every channel uses mean = std = 0.5).
transform = transforms.Compose(
    [
        transforms.Resize(size=(pic_width, pic_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

In [None]:
# Classify a single image and display it with the predicted eye/hair colours.
# PNGs are often RGBA or palette-mode, but the network's first Conv2d expects
# exactly 3 input channels, so force RGB before applying any transforms.
img = Image.open('Chiaki2.png').convert('RGB')
img = color_transform(img)
dimg = transform(img)

# One-image batch. In eval mode the BatchNorm layers use running statistics,
# so the image does not need to be replicated to fill a larger batch.
X = dimg.unsqueeze(0)

# Put the model on the same device as the input, so this cell also works when
# `device` is CUDA. eval() is idempotent and guards against the model having
# been left in training mode.
model = model.to(device)
model.eval()

# Inference only: skip building the autograd graph.
with torch.no_grad():
    Y_pred = model(X.to(device))

# The outputs are the eye-colour scores followed by the hair-colour scores.
n_eye = len(eye_labels)
Y_pred_eye = Y_pred[:, :n_eye]
Y_pred_hair = Y_pred[:, n_eye:]

Y_pred_eye_idx = Y_pred_eye.argmax(dim=1)
Y_pred_hair_idx = Y_pred_hair.argmax(dim=1)

pred_eye_col = eye_labels[Y_pred_eye_idx[0].item()]
pred_hair_col = hair_labels[Y_pred_hair_idx[0].item()]

# The preprocessing normalises to [-1, 1] (mean = std = 0.5). Undo that so
# imshow receives [0, 1] floats instead of silently clipping negative values.
display_img = (X[0] * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0)

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(display_img)
ax.set_title(f'I think the eyes are {pred_eye_col} and the hair is {pred_hair_col}')
plt.show()