In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

from datasets import LSPDataset, LSPExtendedDataset
from vis import show_pose, show_pose_from_hmap, show_hmap
from models import PoseNet

# Model 1

In [None]:
# Model 1 data: the extended LSP set. Plain-LSP alternative kept for reference:
# lspdata = LSPDataset('./dataset/lsp_dataset/', 'points')
lspdata = LSPExtendedDataset(
    './dataset/lspet_dataset/',
)

In [None]:
# Draw one random training sample and overlay its annotation.
sample_idx = np.random.randint(len(lspdata))
a, b, c = lspdata[sample_idx]
show_pose(a, b, c)

In [None]:
dataloader = data.DataLoader(lspdata, batch_size=10, shuffle=True)

In [None]:
class MSE_Loss(nn.Module):
    """Masked sum-of-squared-errors loss over flattened joint coordinates.

    ``mask`` carries one weight per joint. Each weight is duplicated so that it
    covers the two consecutive coordinate entries of its joint, and the
    duplicated mask multiplies the element-wise squared error before summing.
    """

    def __init__(self):
        super(MSE_Loss, self).__init__()

    def forward(self, output, joints, mask):
        """Return ``sum(expanded_mask * (output - joints) ** 2)``.

        Args:
            output: predicted coordinates, presumably ``(batch, 2 * num_joints)``
                (28 columns for the 14-joint data used in this notebook).
            joints: target coordinates, broadcast-compatible with ``output``.
            mask: per-joint weights, presumably ``(batch, num_joints)``.

        Returns:
            Scalar tensor with the masked sum of squared errors.
        """
        # Use the `mask` argument itself (not a notebook global), and derive the
        # width from it instead of hard-coding 28:
        # [m0, m1, ...] -> [m0, m0, m1, m1, ...], one weight per coordinate pair.
        mask = torch.stack((mask, mask), dim=2).reshape(mask.shape[0], -1)
        return torch.sum(mask * (output - joints) ** 2)

In [None]:
# Model 1: build PoseNet on the GPU, warm-start it from saved weights,
# and create the masked loss and an Adam optimizer.
net = PoseNet().cuda()
net.load_state_dict(torch.load('./weights/simple_pose.weights'))
criterion = MSE_Loss().cuda()
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3)

In [None]:
# Set True to write a checkpoint every 10 epochs (previously the save was
# commented out while the "Saving weights" message was still printed).
SAVE_CHECKPOINTS = False

# The spot-check cell below calls net.eval(); switch back so that re-running
# this cell trains with any dropout/batch-norm layers in training mode.
net.train()
for epoch in range(50):
    for batch, batch_data in enumerate(dataloader):
        optimizer.zero_grad()
        im, joints, ignore_joints = batch_data
        im = im.cuda()
        joints = joints.cuda()
        ignore_joints = ignore_joints.cuda()
        output = net(im)
        loss = criterion(output, joints, ignore_joints)
        loss.backward()
        optimizer.step()
        if (batch + 1) % 10 == 0:
            # .item() gives a plain float instead of a device tensor repr.
            print('Batch: {}, Loss: {}, Epoch: {}'.format(batch + 1, loss.item(), epoch))
    if SAVE_CHECKPOINTS and (epoch + 1) % 10 == 0:
        print('Saving weights for epoch: {}'.format(epoch))
        torch.save(net.state_dict(), './weights/simple_pose{}.weights'.format(epoch))

In [None]:
# Spot-check Model 1 on a random sample of the training dataset.
# To persist the weights: torch.save(net.state_dict(), './weights/simple_pose.weights')
net.eval()
a, _, c = lspdata[np.random.randint(len(lspdata))]
with torch.no_grad():  # inference only: don't build an autograd graph
    b = net(a.unsqueeze(0).cuda())
show_pose(a, b)

In [None]:
# Run Model 1 on a random image file from `base_dir`.
base_dir = './dataset/lsp_dataset/images/'
# base_dir = './test/'
t = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
# Sorted so the random pick is reproducible under a fixed seed.
test_image = sorted(os.listdir(base_dir))
image_path = os.path.join(base_dir, test_image[np.random.randint(len(test_image))])
with Image.open(image_path) as img:
    # Force 3 channels: grayscale/RGBA files would otherwise give 1-/4-channel
    # tensors (the network presumably expects RGB input — confirm).
    image = img.convert('RGB')
image_tensor = t(image)
with torch.no_grad():
    pred_joints = net(image_tensor.unsqueeze(0).cuda())
show_pose(image_tensor, pred_joints)


# Model 2

In [None]:
# Model 2 data: the extended LSP set with the 'heatmap' option.
# Plain-LSP alternative kept for reference:
# lspdata = LSPDataset('./dataset/lsp_dataset/', 'points')
lspdata = LSPExtendedDataset(
    './dataset/lspet_dataset/',
    'heatmap',
)

In [None]:
# Draw one random sample and render its pose from the target heatmaps.
sample_idx = np.random.randint(len(lspdata))
a, b, c = lspdata[sample_idx]
show_pose_from_hmap(a, b, c)

In [None]:
dataloader = data.DataLoader(lspdata, batch_size=10, shuffle=True)

In [None]:
class JointsMSELoss(nn.Module):
    """Per-sample MSE between predicted and target maps, summed over the batch.

    Both tensors are reshaped to ``(batch, num_joints, -1)``. ``nn.MSELoss``
    (mean reduction) is applied to each sample separately and the per-sample
    values are added.

    Args:
        num_joints: number of joints per sample (14 in this notebook).
    """

    def __init__(self, num_joints=14):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss()
        self.num_joints = num_joints

    def forward(self, output, target, mask=None):
        """Return the sum over the batch of per-sample mean squared errors.

        ``mask`` is accepted for interface compatibility but is not used.
        """
        # Use the real batch size: the last DataLoader batch can be smaller
        # than 10 when the dataset length is not a multiple of it.
        batch_size = output.shape[0]
        pred = output.reshape(batch_size, self.num_joints, -1)
        tar = target.reshape(batch_size, self.num_joints, -1)
        return sum(self.criterion(p, t) for p, t in zip(pred, tar))

In [None]:
# Model 2: heatmap network, warm-started from saved weights, plain MSE loss.
# NOTE(review): HeatmapPose is not imported above (only PoseNet comes from
# `models`). Its only definition in this notebook is in the "Model 3" cell
# further down, so on a fresh kernel this line raises NameError unless that
# cell ran first. TODO confirm which HeatmapPose produced these weights.
net = HeatmapPose().cuda()
net.load_state_dict(torch.load('./heatmap_pose.weights'))
criterion = nn.MSELoss()  # JointsMSELoss (defined above) is not used here
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-5)

In [None]:
# Re-create Adam with a smaller learning rate; this replaces any previously
# built optimizer, so its accumulated moment estimates are discarded.
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-7)

In [None]:
# Train the heatmap model with plain MSE; the weight file is overwritten after
# every epoch.
net.train()  # make sure any dropout/batch-norm layers are in training mode
for epoch in range(100):
    for batch, batch_data in enumerate(dataloader):
        optimizer.zero_grad()
        im, joints, _ = batch_data
        im = im.cuda()
        joints = joints.cuda()
        output = net(im)
        loss = criterion(output, joints)
        loss.backward()
        optimizer.step()
        if batch % 10 == 0:
            # .item() gives a plain float instead of a device tensor repr.
            print('Batch: {}, Loss: {}, Epoch: {}'.format(batch, loss.item(), epoch))
    torch.save(net.state_dict(), './heatmap_pose.weights')

In [None]:
# Predict heatmaps for one random sample and visualise them.
sample_idx = np.random.randint(len(lspdata))
a, c, m = lspdata[sample_idx]
# Detach from autograd, move to CPU, and squeeze out size-1 dims.
h = net(a.unsqueeze(0).cuda()).detach().cpu().squeeze()
show_hmap(a, h)
show_pose_from_hmap(a, h)

In [None]:
# Sanity check: criterion between `c` (the second item, i.e. the training
# target, of the sample drawn in the previous cell) and the target of another
# random sample.
# NOTE(review): the intended operands are inferred from the code; confirm.
# The dataset yields 3 items per sample (see the 3-way unpacking in the cells
# above), so a 2-way unpack raised ValueError; only the target is needed here.
_, d, _ = lspdata[np.random.randint(len(lspdata))]
a = c.unsqueeze(0)
d = d.unsqueeze(0)
criterion(a, d)

In [None]:
# Run the heatmap model on a random image file from `base_dir`.
base_dir = './New folder/'
# base_dir = './dataset/FLIC/images/'
# base_dir = './dataset/lsp_dataset/images/'
t = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
# Sorted so the random pick is reproducible under a fixed seed.
test_image = sorted(os.listdir(base_dir))
image_path = os.path.join(base_dir, test_image[np.random.randint(len(test_image))])
with Image.open(image_path) as img:
    # Force 3 channels: grayscale/RGBA files would otherwise give 1-/4-channel
    # tensors (the network presumably expects RGB input — confirm).
    image = img.convert('RGB')
image = t(image)
with torch.no_grad():  # inference only: no autograd graph
    h = net(image.unsqueeze(0).cuda()).squeeze().cpu()
show_pose_from_hmap(image, h)
show_hmap(image, h)


# Model 3

In [None]:
class HeatmapPose(nn.Module):
    """Fully-convolutional feature extractor at a single input resolution.

    Three stages of 5x5 conv (padding 2) + ReLU + 2x2 max-pool, followed by a
    9x9 conv (padding 4) + ReLU. For input ``(N, 3, H, W)`` the output is
    ``(N, 512, H // 8, W // 8)`` (each pool floors the spatial size).
    """

    def __init__(self):
        super(HeatmapPose, self).__init__()
        # Keep this creation order: it fixes the state_dict key order and the
        # RNG draws used for weight initialisation.
        self.conv1 = nn.Conv2d(3, 128, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=5, padding=2)
        self.conv33 = nn.Conv2d(128, 512, kernel_size=9, padding=4)

    def forward(self, x):
        # (N, 3, H, W) -> (N, 128, H/2, W/2) -> (N, 128, H/4, W/4)
        #              -> (N, 128, H/8, W/8)
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        # Size-preserving 9x9 conv: (N, 128, H/8, W/8) -> (N, 512, H/8, W/8)
        return F.relu(self.conv33(x))


class HeatmapPoseA(nn.Module):
    """Multi-scale wrapper: one shared HeatmapPose applied at several resolutions.

    The input is processed as-is, then successively resized (nearest) to each
    size in ``sizes`` and processed again with the same weights. The coarser
    outputs are resized to the spatial size of the full-resolution output and
    everything is concatenated along channels, giving
    ``(N, 512 * (1 + len(sizes)), H // 8, W // 8)``.

    The previous version concatenated maps of different spatial sizes along
    dim 0, which always raised. Channel concatenation after resizing is an
    assumption about the intended design — TODO confirm.

    Args:
        sizes: square side lengths of the successive downsampled inputs.
            Each must be at least 8 so the three poolings keep a non-empty map.
    """

    def __init__(self, sizes=(64, 32)):
        super(HeatmapPoseA, self).__init__()
        self.hmap_pose = HeatmapPose()
        self.sizes = tuple(sizes)

    def forward(self, x):
        feats = [self.hmap_pose(x)]
        for size in self.sizes:
            # Each level resizes the previous (already resized) input.
            x = F.interpolate(x, size)
            feats.append(self.hmap_pose(x))
        full_res = feats[0].shape[-2:]
        aligned = [feats[0]] + [F.interpolate(f, size=full_res) for f in feats[1:]]
        return torch.cat(aligned, dim=1)