In [1]:
import os
import numpy as np
import pandas as pd
import cv2
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

In [2]:
images = []
ages = []
for image in os.listdir("../input/utkface-new/crop_part1"):
    age = int(image.split("_")[0])
    ages.append(age)
    img = cv2.imread(f"../input/utkface-new/crop_part1/{image}")
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    images.append(img)

In [3]:
images = pd.Series(images,name="Images")
ages = pd.Series(ages,name="Ages")
df = pd.concat([images,ages],axis=1)
under_4 = df[df["Ages"]<=4]
under_4_new = under_4.sample(frac=0.3)
up_4 = df[df["Ages"]>4]
df = pd.concat([under_4_new,up_4],axis=0)
df = df[df["Ages"]<90]

In [4]:
X = np.array(df["Images"].tolist())
Y = np.array(df["Ages"].tolist())

In [5]:
X = X.reshape((-1, 3, X.shape[1], X.shape[2]))
Y = Y.reshape(Y.shape[0], 1)
X = torch.tensor(X)
Y = torch.tensor(Y)

In [6]:
transform = torchvision.transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

In [7]:
class CustomDataset(Dataset):
    def __init__(self,image,label,transform):
        super().__init__()
        self.image = image
        self.label = label
        self.transform = transform
        
    def __getitem__(self,index):
        label = self.label[index]
        image = self.image[index]
        image = self.transform(image)
        
        return image, label
    
    def __len__(self):
        return len(self.label)

In [8]:
dataset = CustomDataset(X,Y,transform)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset,[train_size,test_size])
train = torch.utils.data.DataLoader(train_dataset,batch_size=32,shuffle=True)
test = torch.utils.data.DataLoader(test_dataset,batch_size=32,shuffle=False)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet152(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth


  0%|          | 0.00/230M [00:00<?, ?B/s]

In [10]:
in_features = model.fc.in_features
model.fc = nn.Linear(in_features,1)
model = model.to(device)
ct = 0
for child in model.children():
    ct +=1
    if ct<7:
        for param in child.parameters():
            param.requires_grad = False

In [11]:
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
loss_fuction = nn.MSELoss()

In [12]:
for epoch in range(25):
    train_loss = 0.0
    for image, label in tqdm(train):
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        predict = model(image.float())
        loss = loss_fuction(predict,label.float())
        
        loss.backward()
        optimizer.step()
        train_loss += loss
    
    total_loss = train_loss / len(train)
    print(f"Epochs:{epoch+1}, Loss:{total_loss}")

100%|██████████| 205/205 [01:24<00:00,  2.44it/s]


Epochs:1, Loss:371.1922607421875


100%|██████████| 205/205 [01:18<00:00,  2.60it/s]


Epochs:2, Loss:263.6663818359375


100%|██████████| 205/205 [01:17<00:00,  2.63it/s]


Epochs:3, Loss:235.28988647460938


100%|██████████| 205/205 [01:17<00:00,  2.64it/s]


Epochs:4, Loss:208.91554260253906


100%|██████████| 205/205 [01:17<00:00,  2.64it/s]


Epochs:5, Loss:191.36643981933594


100%|██████████| 205/205 [01:17<00:00,  2.65it/s]


Epochs:6, Loss:167.08139038085938


100%|██████████| 205/205 [01:17<00:00,  2.64it/s]


Epochs:7, Loss:137.16204833984375


100%|██████████| 205/205 [01:17<00:00,  2.65it/s]


Epochs:8, Loss:104.37530517578125


100%|██████████| 205/205 [01:17<00:00,  2.64it/s]


Epochs:9, Loss:80.8545913696289


100%|██████████| 205/205 [01:17<00:00,  2.64it/s]


Epochs:10, Loss:64.69515228271484


100%|██████████| 205/205 [01:18<00:00,  2.62it/s]


Epochs:11, Loss:51.39288330078125


100%|██████████| 205/205 [01:17<00:00,  2.65it/s]


Epochs:12, Loss:46.36651611328125


100%|██████████| 205/205 [01:17<00:00,  2.65it/s]


Epochs:13, Loss:43.68232727050781


100%|██████████| 205/205 [01:17<00:00,  2.64it/s]


Epochs:14, Loss:34.683128356933594


100%|██████████| 205/205 [01:17<00:00,  2.64it/s]


Epochs:15, Loss:30.784748077392578


100%|██████████| 205/205 [01:19<00:00,  2.57it/s]


Epochs:16, Loss:30.316654205322266


100%|██████████| 205/205 [01:17<00:00,  2.64it/s]


Epochs:17, Loss:31.579126358032227


100%|██████████| 205/205 [01:17<00:00,  2.64it/s]


Epochs:18, Loss:34.34281539916992


100%|██████████| 205/205 [01:17<00:00,  2.63it/s]


Epochs:19, Loss:33.765926361083984


100%|██████████| 205/205 [01:17<00:00,  2.65it/s]


Epochs:20, Loss:29.122690200805664


100%|██████████| 205/205 [01:17<00:00,  2.65it/s]


Epochs:21, Loss:29.277246475219727


100%|██████████| 205/205 [01:17<00:00,  2.63it/s]


Epochs:22, Loss:23.71278190612793


100%|██████████| 205/205 [01:17<00:00,  2.65it/s]


Epochs:23, Loss:24.83338165283203


100%|██████████| 205/205 [01:17<00:00,  2.65it/s]


Epochs:24, Loss:25.394790649414062


100%|██████████| 205/205 [01:17<00:00,  2.64it/s]

Epochs:25, Loss:25.86041259765625





In [13]:
torch.save(model.state_dict(),"weights.pth")