In [1]:
import sys
import torch
import torchvision

# Prepare Data

## Loading the dataset


### Custom dataset



In [2]:
import csv
import os
import numpy as np
from PIL import Image
import torch

class MTLData(torch.utils.data.Dataset):
    def __init__(self, csv_file, mode='train', transform=None):
        self.mode = mode # 'train', 'val' or 'test'
        self.data_list = []
        self.race_labels = [] 
        self.gender_labels = []
        self.age_labels = []
        self.transform = transform
        
        with open(csv_file, newline='') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                self.data_list.append(row['img_name'])
                if mode != 'test':
                    self.race_labels.append(row['race'])
                    self.gender_labels.append(row['gender'])
                    self.age_labels.append(row['age'])

    def __getitem__(self, index):
        max_age = 116
        data = Image.open(self.data_list[index])
        if self.transform is not None:
            data = self.transform(data)
        if self.mode == 'test':
            return data
        race_label = torch.tensor(int(self.race_labels[index]))
        gender_label = torch.tensor(int(self.gender_labels[index]))
        age_label = torch.tensor(int(self.age_labels[index])/max_age)

        return data, race_label, gender_label, age_label

    def __len__(self):
        return len(self.data_list)

### Data augmentation 



In [6]:
from torchvision import transforms, models
# For TRAIN

transforms_train = transforms.Compose(
    [transforms.Resize((196, 196)),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.RandomVerticalFlip(p=0.5),
     #transforms.RandomCrop(224, padding=4),
     transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
    ]
)

# For VAL, TEST

transforms_test = transforms.Compose(
    [transforms.Resize((196, 196)),
#    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
    ]
)

### Instantiate dataset


In [7]:
dataset_train = MTLData('/home/ctku/Code/Course/DL/Final_Project/train.csv', mode='train', transform=transforms_train)
dataset_val = MTLData('/home/ctku/Code/Course/DL/Final_Project/val.csv', mode='val', transform=transforms_test)
dataset_test = MTLData('/home/ctku/Code/Course/DL/Final_Project/test.csv', mode='test', transform=transforms_test)

In [8]:
print("The first image's shape in dataset_train :", dataset_train.__getitem__(0)[0].size())
print("There are", dataset_train.__len__(), "images in dataset_train.")

The first image's shape in dataset_train : torch.Size([3, 196, 196])
There are 14223 images in dataset_train.


In [9]:
print(dataset_train.__getitem__(10)[1])
print(dataset_train.__getitem__(10)[2])

tensor(0)
tensor(0)


### `DataLoader`


In [10]:
from torch.utils.data import DataLoader

batch_size = 64
num_workers = 2
train_loader = DataLoader(dataset_train, batch_size=batch_size, num_workers=num_workers, shuffle=True)
val_loader = DataLoader(dataset_val, batch_size=batch_size, num_workers=num_workers, shuffle=False)
test_loader = DataLoader(dataset_test, batch_size=batch_size, num_workers=num_workers, shuffle=False)

# Implement MAE using PyTorch

### Define a Masked Autoencoder



In [11]:
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import math

class MAE(nn.Module):
    def __init__(self, image_channel, image_size, patch_size, enc_dim, dec_dim, encoder, decoder, mask_ratio=0.75) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.patch_dim = patch_size * patch_size * image_channel
        self.token_num = (image_size//patch_size)**2
        # Note that the input to the torch.nn.Transformer have the batch dimension in the middle: [T(token), B(batch), D(feature)]
        
        self.shuffler = PatchShuffler(mask_ratio, self.token_num)
        
        self.register_buffer('enc_pos', positional_encoding(enc_dim, max_len=self.token_num))
        self.register_buffer('dec_pos', positional_encoding(dec_dim, max_len=self.token_num))
        
        self.mask_emb = nn.Parameter(torch.randn(dec_dim))
        
        self.in_proj = nn.Linear(self.patch_dim, enc_dim)
        self.encoder = Transformer(d_model=enc_dim, **encoder)
        self.mid_proj = nn.Linear(enc_dim, dec_dim) if enc_dim != dec_dim else nn.Identity()
        self.decoder = Transformer(d_model=dec_dim, **decoder)
        self.out_proj = nn.Linear(dec_dim, self.patch_dim) if dec_dim != self.patch_dim else nn.Identity()
        
    def forward(self, img, viz=False):
        
        self.shuffler.init_rand_idx(img.shape[0], img.device)
        
        
        patches = rearrange(img, 'b c (h s1) (w s2) -> (h w) b (s1 s2 c)', s1=self.patch_size, s2=self.patch_size)
        
        emb = self.in_proj(patches)

        _, enc_inp = self.shuffler.shuffle_split(emb + self.enc_pos)
        
        
        x = self.encoder(enc_inp)
        x = self.mid_proj(x)

        x = torch.cat([x, self.mask_emb.expand(self.token_num-x.shape[0], x.shape[1], -1)])
        dec_pos = self.shuffler.shuffle(self.dec_pos.expand_as(x))
        dec_out = self.decoder(x + dec_pos)

        pixel_recon = self.out_proj(dec_out)
        
        inpainted_patches, _ = self.shuffler.split(pixel_recon)
        
        # get target from input patches
        masked_patches, _ = self.shuffler.shuffle_split(patches)
        
        loss = F.mse_loss(inpainted_patches, masked_patches)
        
        if viz:
            img_recon = self.shuffler.unshuffle(pixel_recon)
            img_recon = rearrange('(h w) b (s1 s2 c) -> b c (h s1) (w s2)', h=img.shape[2]//self.patch_size, s1=self.patch_size, s2=self.patch_size)
            return {'loss':loss, 'recon': img_recon}
        
        return {'loss':loss}
    
class PatchShuffler(nn.Module):
    def __init__(self, ratio=0.75, token_num=196):
        super().__init__()
        self.mask_n = int(ratio*token_num)
        self.token_n = token_num
        
    def init_rand_idx(self, batch_size, device) -> None:
        self.rand_idx = torch.rand(self.token_n, batch_size, device=device).argsort(dim=0)
        self.sort_idx = torch.argsort(self.rand_idx, dim=0).to(device)
        
    def shuffle(self, x):
        return x.gather(0, self.rand_idx.unsqueeze(-1).expand_as(x))
    
    def unshuffle(self, x):
        return x.gather(0, self.sort_idx.unsqueeze(-1).expand_as(x))
    
    def shuffle_split(self, x):
        x = self.shuffle(x)
        return x.split(self.mask_n)

    def split(self, x):
        return x.split(self.mask_n)
    
def positional_encoding(d_model, max_len=5000):
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, 1, d_model)
    pe[:, 0, 0::2] = torch.sin(position * div_term)
    pe[:, 0, 1::2] = torch.cos(position * div_term)
    return pe

class Transformer(nn.Module):
    def __init__(self, num_layers, norm, d_model, **layer_kwargs):
        super().__init__()
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, **layer_kwargs),
            num_layers=num_layers,
            norm=norm
        )
    def forward(self, *args, **kwargs):
        return self.transformer(*args, **kwargs)


In [12]:
import torch.nn as nn 
import torch.nn.functional as F

class MTL(nn.Module): 
    def __init__(self): 
        super(MTL, self).__init__()
        
#        ************************************
        
#        此段將 MAE output 串接MTL的 code 
        
#        ************************************
        self.rn = models.resnet34(pretrained=True)
        num_ftrs = self.rn.fc.in_features
        
        self.rn.fc = nn.Linear(num_ftrs, num_ftrs)
        self.rn.fc1 = nn.Linear(num_ftrs, 5)
        self.rn.fc2 = nn.Linear(num_ftrs, 2)
        self.rn.fc3 = nn.Linear(num_ftrs, 1)
        # self.dropout = nn.Dropout(0.5)


    def forward(self, x): 
        if not isinstance(x, torch.Tensor):
            x = torch.Tensor(x)
        
        out = F.relu(self.rn(x))
#         out = self.rn18(x)
#         out = self.rn18.fc(x)
#         out1 = self.classifier1(out)
#         out2 = self.classifier2(out)
        #  out = self.dropout(out)
        # out = self.fc3(out)
        
        #race
        out1 = self.rn.fc1(out)
        #gender
        out2 = self.rn.fc2(out)
#         out2 = torch.sigmoid(out2)
#         out2 = out2.view(-1)
        #age
        out3 = self.rn.fc3(out)
        out3 = torch.sigmoid(out3)
        out3 = out3.view(-1)

        return out1, out2, out3

In [13]:
class MAEtoMTL(nn.Module):
    def __init__(self, image_channel, image_size, patch_size, enc_dim, dec_dim, encoder, decoder, mask_ratio=0.75):
        super(MAEtoMTL, self).__init__()

        self.mae = MAE(image_channel, image_size, patch_size, enc_dim, dec_dim, encoder, decoder, mask_ratio=0.75)
        self.mtl = MTL()
        
    def forward(self, x):
        
        output =  self.mae(x)
        final_outputs = self.mtl(output)
        
        return final_outputs

In [14]:
device = torch.device('cuda:0')

model = MAEtoMTL(
        image_size=196,
        image_channel=3,
        patch_size=16,
        enc_dim=512,
        dec_dim=256,
        encoder=dict(
            num_layers=12,
            norm=None,
            nhead=8,
            dim_feedforward=2048,
            dropout=0,
            activation='relu'),
        decoder=dict(
            num_layers=12,
            norm=None,
                nhead=4,
                dim_feedforward=1024,
                dropout=0,
                activation='relu'),
        mask_ratio=0.75)

model = model.to(device)
model = model.eval()
print(model)

MAEtoMTL(
  (mae): MAE(
    (shuffler): PatchShuffler()
    (in_proj): Linear(in_features=768, out_features=512, bias=True)
    (encoder): Transformer(
      (transformer): TransformerEncoder(
        (layers): ModuleList(
          (0): TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
            )
            (linear1): Linear(in_features=512, out_features=2048, bias=True)
            (dropout): Dropout(p=0, inplace=False)
            (linear2): Linear(in_features=2048, out_features=512, bias=True)
            (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (dropout1): Dropout(p=0, inplace=False)
            (dropout2): Dropout(p=0, inplace=False)
          )
          (1): TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out

### Define loss and optimizer

In [15]:
import torch.nn as nn
import torch.optim as optim

criterion1 = nn.CrossEntropyLoss()
# criterion2 = nn.BCELoss()
criterion2 = nn.CrossEntropyLoss()
criterion3 = nn.MSELoss()
# criterion2 = nn.MultiLabelSoftMarginLoss()
# criterion2 = nn.CrossEntropyLoss() 
# optimizer = optim.Adam(model.parameters(), lr = 0.0001)
# optimizer = optim.SGD(model.parameters(), lr = 0.0001)
optimizer = optim.AdamW(model.parameters(), lr = 0.0001)

criterion1 = criterion1.to(device)
criterion2 = criterion2.to(device)
criterion3 = criterion3.to(device)

In [16]:
max_epochs = 3

In [17]:
def lr_lambda(max_epochs):
        if epoch < args.epochs_warmup:
            p = epoch / args.epochs_warmup
            lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)
        else:
            eta_min = args.lr * (args.lr_decay_rate ** 3)
            lr = eta_min + (args.lr - eta_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2
        return lr

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

### Train the model

#### Train functionss and total accuracy.

Hint: [torch.max()](https://pytorch.org/docs/stable/generated/torch.max.html#torch-max) or [torch.argmax()](https://pytorch.org/docs/stable/generated/torch.argmax.html)

In [18]:
from sklearn.metrics import f1_score

def train(input_data, model, criterion1, criterion2, criterion3, optimizer):
    '''
    Argement:
    input_data -- iterable data, typr torch.utils.data.Dataloader is prefer
    model -- nn.Module, model contain forward to predict output
    criterion -- loss function, used to evaluate goodness of model
    optimizer -- optmizer function, method for weight updating
    '''
    model.train()
    loss_list = []
#     f1 = 0.0
    total_count = 0
    acc_count_1 = 0
    acc_count_2 = 0
    acc_count_3 = 0
    for i in input_data:
        images = i[0].to(device)
        cate = i[1].to(device)
        attr = i[2].to(device)
        age = i[3].to(device)
        
        optimizer.zero_grad()

        outputs1, outputs2, outputs3 = model(images)

        loss1 = criterion1(outputs1, cate)
        loss2 = criterion2(outputs2, attr)
        loss3 = criterion3(outputs3, age.float())
        loss = loss1 + loss2 + loss3
        loss.backward()
        optimizer.step()
        
        scheduler.step()
        
        loss_list.append(loss.item())

        _, predicted_ca = torch.max(outputs1.data, 1)
        _, predicted_at = torch.max(outputs2.data, 1)
        total_count += cate.size(0)
        #binary using acc
        acc_count_1 += (predicted_ca == cate).sum().item()

        #cate using acc
#         _, predicted_at = torch.max(outputs2.data, 1)
#         acc_count_2 += (predicted_at == attr).sum().item()
        acc_count_2 += (predicted_at == attr).sum().item()
        
        acc_count_3 += ((outputs3 > 0.5) == age).sum().item()
        task1_loss = loss1
        task2_loss = loss2
        task3_loss = loss3

        
#         predicted_at = (torch.sigmoid(outputs2) > 0.5).int()
        #attr using F1

#         f1+= f1_score(attr.int().to('cpu').numpy(),predicted_at.to('cpu').numpy(),average='samples')
        

    # Compute this epoch accuracy and loss
    acc_1 = acc_count_1 / total_count
    acc_2 = acc_count_2 / total_count
    acc_3 = acc_count_3 / total_count
#     F1 = f1 / len(loss_list) 
    loss = sum(loss_list) / len(loss_list)
    return acc_1, acc_2, acc_3, task1_loss, task2_loss, task3_loss, loss

#### Validate function

In [19]:
def val(input_data, model, criterion1, criterion2, criterion3):
    model.eval()
    
    loss_list = []
    total_count = 0
    acc_count_1 = 0
    acc_count_2 = 0
    acc_count_3 = 0
#     f1 = 0.0
    with torch.no_grad():
        for i in input_data:
            images = i[0].to(device)
            cate = i[1].to(device)
            attr = i[2].to(device)
            age = i[3].to(device)

            outputs1, outputs2, outputs3 = model(images)
            loss1 = criterion1(outputs1, cate)
            loss2 = criterion2(outputs2, attr)
            loss3 = criterion3(outputs3, age.float())
            loss = loss1 + loss2 + loss3
            
            loss_list.append(loss.item())
            
            _, predicted_cate = torch.max(outputs1.data, 1)
            _, predicted_at = torch.max(outputs2.data, 1)
            total_count += cate.size(0)
            acc_count_1 += (predicted_cate == cate).sum().item()

            #cate using acc
#             _, predicted_at = torch.max(outputs2.data, 1)
#             acc_count_2 += (predicted_at == attr).sum().item()
            acc_count_2 += (predicted_at == attr).sum().item()
            task1_loss = loss1
            task2_loss = loss2
            task3_loss = loss3
            acc_count_3 += ((outputs3 > 0.5) == age).sum().item()
#             predicted_attr = (torch.sigmoid(outputs2) > 0.5).int()
#             f1+= f1_score(attr.int().to('cpu').numpy(),predicted_attr.to('cpu').numpy(),average='samples')

            
    acc_1 = acc_count_1 / total_count
    acc_2 = acc_count_2 / total_count
    acc_3 = acc_count_3 / total_count
#     F1 = f1 / len(loss_list)
    loss = sum(loss_list) / len(loss_list)
    return acc_1, acc_2, acc_3, task1_loss, task2_loss, task3_loss, loss

#### Training in a loop

In [None]:
# max_epochs = 10
log_interval = 1 # print acc and loss in per log_interval time

train_acc1_list = []
train_acc2_list = []
train_acc3_list = []
train_loss1_list = []
train_loss2_list = []
train_loss3_list = []
# train_f1_list = []
train_loss_list = []
val_acc1_list = []
val_acc2_list = []
val_acc3_list = []
val_loss1_list = []
val_loss2_list = []
val_loss3_list = []
# val_f1_list = []
val_loss_list = []      
        
for epoch in range(1, max_epochs + 1):
    train_acc1, train_acc2, train_acc3, train_loss1, train_loss2, train_loss3, train_loss = train(train_loader, model, criterion1, criterion2, criterion3, optimizer)
    val_acc1, val_acc2, val_acc3, val_loss1, val_loss2, val_loss3, val_loss = val(val_loader, model, criterion1, criterion2, criterion3)
    train_acc1_list.append(train_acc1)
    train_acc2_list.append(train_acc2)
    train_acc3_list.append(train_acc3)
    train_loss1_list.append(train_loss1)
    train_loss2_list.append(train_loss2)
    train_loss3_list.append(train_loss3)
    train_loss_list.append(train_loss)
    
    val_acc1_list.append(val_acc1)
    val_acc2_list.append(val_acc2)
    val_acc3_list.append(val_acc3)
    val_loss1_list.append(val_loss1)
    val_loss2_list.append(val_loss2)
    val_loss3_list.append(val_loss3)
    val_loss_list.append(val_loss)
    if epoch % log_interval == 0:          
        print('=' * 40, 'Epoch', epoch, '=' * 40)     
        print('task1 Train Acc: {:.6f} task2 Train Acc: {:.6f} task3 Train Acc: {:.6f} task1 Train Loss: {:.6f} task2 Train Loss: {:.6f} task3 Train Loss: {:.6f}  Train Loss: {:.6f}'.format(train_acc1, train_acc2, train_acc3, train_loss1, train_loss2, train_loss3, train_loss)) 
        print('task1   Val Acc: {:.6f} task2   Val Acc: {:.6f} task3   Val Acc: {:.6f} task1   Val Loss: {:.6f} task2   Val Loss: {:.6f} task3   Val Loss: {:.6f}    Val Loss: {:.6f}'.format(val_acc1, val_acc2, val_acc3, val_loss1, val_loss2, val_loss3, val_loss))          

In [None]:
# save your well-trained state_dict of model          
torch.save(model.state_dict(), 'MAEtoMTL.pt')   

#### Visualize accuracy and loss

In [None]:
import matplotlib.pyplot as plt


plt.figure(figsize=(12, 4))
plt.plot(range(len(train_loss_list)), train_loss_list)
plt.plot(range(len(val_loss_list)), val_loss_list, c='r')
plt.legend(['train', 'val'])
plt.title('Loss')
plt.show()

plt.figure(figsize=(12, 4))
plt.plot(range(len(train_acc1_list)), train_acc1_list)
plt.plot(range(len(val_acc1_list)), val_acc1_list, c='r')
plt.legend(['train', 'val'])
plt.title('Acc')
plt.show()

plt.figure(figsize=(12, 4))
plt.plot(range(len(train_acc2_list)), train_acc2_list)
plt.plot(range(len(val_acc2_list)), val_acc2_list, c='r')
plt.legend(['train', 'val'])
plt.title('Acc')
plt.show()

plt.figure(figsize=(12, 4))
plt.plot(range(len(train_acc3_list)), train_acc3_list)
plt.plot(range(len(val_acc3_list)), val_acc3_list, c='r')
plt.legend(['train', 'val'])
plt.title('Acc')
plt.show()

train_loss1_list = np.array(torch.tensor(train_loss1_list, device='cpu'))
val_loss1_list = np.array(torch.tensor(val_loss1_list, device='cpu'))
plt.figure(figsize=(12, 4))
plt.plot(range(len(train_loss1_list)), train_loss1_list)
plt.plot(range(len(val_loss1_list)), val_loss1_list, c='r')
plt.legend(['train', 'val'])
plt.title('Task1 Loss')
plt.show()

train_loss2_list = np.array(torch.tensor(train_loss2_list, device='cpu'))
val_loss2_list = np.array(torch.tensor(val_loss2_list, device='cpu'))
plt.figure(figsize=(12, 4))
plt.plot(range(len(train_loss2_list)), train_loss2_list)
plt.plot(range(len(val_loss2_list)), val_loss2_list, c='r')
plt.legend(['train', 'val'])
plt.title('Task2 Loss')
plt.show()

train_loss3_list = np.array(torch.tensor(train_loss3_list, device='cpu'))
val_loss3_list = np.array(torch.tensor(val_loss3_list, device='cpu'))
plt.figure(figsize=(12, 4))
plt.plot(range(len(train_loss3_list)), train_loss3_list)
plt.plot(range(len(val_loss3_list)), val_loss3_list, c='r')
plt.legend(['train', 'val'])
plt.title('Task3 Loss')
plt.show()

### Predict Result



In [None]:
# # if you wanna load previous best model
# ckpt = torch.load('MTL.pt')
# model.load_state_dict(ckpt) 

In [None]:
def predict(input_data, model):
    model.eval()
    cate_list = []
    attr_list = []
    age_list = []
    with torch.no_grad():
        for images in input_data:
            max_age = 116
            images = images.to(device)
            output1, output2, output3 = model(images)
            _, predicted_cate = torch.max(output1.data, 1)
            _, predicted_attr = torch.max(output2.data, 1)
#             predicted_attr = torch.sigmoid(output2) > 0.5
            predicted_age = output3 * max_age
            cate_list.extend(predicted_cate.to('cpu').numpy().tolist())
            attr_list.extend(predicted_attr.to('cpu').numpy().tolist())
            age_list.extend(predicted_age.to('cpu').numpy().tolist())
            
    return cate_list, attr_list, age_list 

In [None]:
data_folder = 'UTKFace_csv'
cate_csv, attr_csv, age_csv = predict(test_loader, model)
with open('UTKFace_csv/result_race.csv', 'w', newline='') as csvFile:
    writer = csv.DictWriter(csvFile, fieldnames=['file_path', 'race_label'])
    writer.writeheader()
    idx = 0
    for result in cate_csv:
        file_path = dataset_test.data_list[idx].replace(data_folder + '/', '')
        file_path = data_folder + '/' + file_path
        writer.writerow({'file_path':file_path, 'race_label':result})
        idx += 1

with open('UTKFace_csv/result_gender.csv', 'w', newline='') as csvFile:
    writer = csv.DictWriter(csvFile, fieldnames=['file_path', 'gender_label'])
    writer.writeheader()
    idx = 0
    for result in attr_csv:
        file_path = dataset_test.data_list[idx].replace(data_folder + '/', '')
        file_path = data_folder + '/' + file_path
        writer.writerow({'file_path':file_path, 'gender_label':result})
        idx += 1
        

with open('UTKFace_csv/result_age.csv', 'w', newline='') as csvFile:
    writer = csv.DictWriter(csvFile, fieldnames=['file_path', 'age_label'])
    writer.writeheader()
    idx = 0
    for result in age_csv:
        file_path = dataset_test.data_list[idx].replace(data_folder + '/', '')
        file_path = data_folder + '/' + file_path
        writer.writerow({'file_path':file_path, 'age_label':result})
        idx += 1

In [None]:
from PIL import Image
img = Image.open("UTKFace/4_1_1_20170112210910341.jpg.chip.jpg")
(w, h) = img.size
print('w=%d, h=%d' % (w, h))
img.show()

In [None]:
import pandas as pd
from PIL import Image
result1 = pd.read_csv('result_age.csv')
origin1 = pd.read_csv('test.csv')

total = pd.concat([result1, origin1], axis=1, ignore_index=True)
# print(total.columns)
# total[1]
total = total[(total[1] - total[3]) > 10]

# total = total.to_frame()


for i, j in zip(total[0][:11], total[1][:11]):
        print(i)
        print(j)
        img = Image.open(i[12:])
        (w, h) = img.size
        
        img.show()
        
        

In [None]:
img = Image.open("UTKFace/25_0_0_20170117140540912.jpg.chip.jpg")

img

In [None]:
img_trans = transforms_test(img)
type(img_trans)
img_trans.shape

In [None]:
import torch

model =  MTL_Model()
ckpt = torch.load('MTL.pt')
model.load_state_dict(ckpt) 

In [None]:
race, gender, age = model(img_trans.unsqueeze(dim=0))

In [None]:
race # white, balck, asian, india, others

In [None]:
race_p = torch.softmax(race, dim=1)
race_p

In [None]:
gender

In [None]:
age.item() * 116