This file should be placed inside detectron2/projects/DensePose/ for correct imports. 
Link to the densepose2 framework: 
https://github.com/facebookresearch/detectron2

In [1]:
from detectron2.config import get_cfg
from detectron2.engine.defaults import DefaultPredictor
from detectron2.modeling import build_model
from densepose import add_densepose_config, add_hrnet_config
from detectron2.checkpoint import DetectionCheckpointer

In [2]:
# Load densepose model with learned weights
cfg = get_cfg()
add_densepose_config(cfg)
add_hrnet_config(cfg)

cfg.merge_from_file("configs/densepose_rcnn_R_50_FPN_s1x.yaml")

cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl"
cfg.freeze()
model = build_model(cfg)
model.eval()
checkpointer = DetectionCheckpointer(model)
checkpointer.load(cfg.MODEL.WEIGHTS)

{'__author__': 'Detectron2 Model Zoo'}

In [3]:
type(model)

detectron2.modeling.meta_arch.rcnn.GeneralizedRCNN

In [4]:
model

GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
      (res2): Sequential(
        (0): BottleneckBlock

In [5]:
import detectron2
b = model.backbone
b
type(b)
# ??detectron2.modeling.backbone.fpn.FPN

detectron2.modeling.backbone.fpn.FPN

In [6]:
# print(str(cfg))

In [7]:
import torch
from torch.nn import AdaptiveAvgPool2d
from torch.nn import Linear, Module
from torch.nn.functional import softmax
from torch.nn.functional import relu
from torch.nn import BatchNorm1d

In [8]:
# Check what the backbone from the original model returns
img = torch.zeros(size=(1, 3, 480, 640)).cuda()
out = b(img)
for k, v in out.items():
    print(k, v.size())

# Get the output with the smallest resolution and squash it to 2d    
a = AdaptiveAvgPool2d((1, 1))
squashed = a(out['p6'])
print(squashed.size())
reshaped = squashed.view((-1, 256))
print(reshaped.size())

p2 torch.Size([1, 256, 120, 160])
p3 torch.Size([1, 256, 60, 80])
p4 torch.Size([1, 256, 30, 40])
p5 torch.Size([1, 256, 15, 20])
p6 torch.Size([1, 256, 8, 10])
torch.Size([1, 256, 1, 1])
torch.Size([1, 256])


In [9]:
# Define yoga model with the backbone from dense pose
class YogaPoseEstimatorModel(Module):
    def __init__(self, backbone, num_classes, pixel_mean, pixel_std):
        super().__init__()
        self.backbone = backbone
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.fc1 = Linear(256, 192)
        self.bn1 = BatchNorm1d(192)
        self.fc2 = Linear(192, 128)
        self.bn2 = BatchNorm1d(128)
        self.fc3 = Linear(128, 64)
        self.bn3 = BatchNorm1d(64)
        self.fc4 = Linear(64, num_classes)
        self.pixel_mean = pixel_mean
        self.pixel_std = pixel_std
    
    def forward(self, x):
        x = self.preprocess(x)
        x = self.backbone(x)['p6']
        x = self.avg_pool(x)
        x = x.view((-1, 256))
        x = self.fc1(x)
        x = self.bn1(x)
        x = relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = relu(x)
        x = self.fc3(x)
        x = self.bn3(x)
        x = relu(x)
        x = self.fc4(x)
        x = softmax(x, dim=1)
        return x
    
    def preprocess(self, tensor):
        # Preprocessing from the source code for the original model
        tensor = (tensor - self.pixel_mean) / self.pixel_std
        return tensor

In [10]:
# We have 10 classes for simple yoga dataset
# pixel_mean and pixel_std parameters from the original model
my_model = YogaPoseEstimatorModel(model.backbone, 82, model.pixel_mean, model.pixel_std)
my_model.cuda()
print(my_model.pixel_mean)
print(my_model.pixel_std)

tensor([[[103.5300]],

        [[116.2800]],

        [[123.6750]]], device='cuda:0')
tensor([[[1.]],

        [[1.]],

        [[1.]]], device='cuda:0')


In [11]:
from detectron2.data.detection_utils import read_image
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader

In [12]:
class YogaDataset(Dataset):
    def __init__(self, image_dir, split, train_fraction=0.8):
        super().__init__()
        
        class_names = os.listdir(image_dir)
        self.all_classes = sorted([c for c in class_names if not c.startswith('.')])
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(self.all_classes)}
        
        np.random.seed(42)
        
        self.data = []
        for class_name in self.all_classes:
            filenames = os.listdir(os.path.join(image_dir, class_name))
            filenames = [f for f in filenames if f.endswith('.jpg')]
            np.random.shuffle(filenames)
            num_train = int(train_fraction * len(filenames))
            if split == 'train':
                filenames = filenames[:num_train]
            elif split == 'test':
                filenames = filenames[num_train:]
            else:
                raise ValueError('Unknown split: ' + split)
            for filename in filenames:
                self.data.append((os.path.join(image_dir, class_name, filename), class_name))
    
    def __getitem__(self, idx):
        image_path, class_name = self.data[idx]
        
        image = read_image(image_path, format='BGR')
        height, width, _ = image.shape
        transform = detectron2.data.transforms.transform.ResizeTransform(h=height, w=width, new_h=800, new_w=800, interp=2)
        image = transform.apply_image(image)
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
        
        class_idx = self.class_to_idx[class_name]
        return image, class_idx
    
    def __len__(self):
        return len(self.data)


In [13]:
train_set = YogaDataset('/home/ubuntu/Yoga_82_Images', 'train')

In [14]:
# # for data cleaning
# for i in range(len(train_set)):
#     try:
#         _ = train_set[i]
#     except Exception as e:
#         print(i, e)

In [15]:
test_set = YogaDataset('/home/ubuntu/Yoga_82_Images', 'test')
print(test_set.data[0])
print(test_set.data[1])

len(test_set)

# # for data cleaning
# for i in range(len(train_set)):
#     try:
#         _ = train_set[i]
#     except Exception as e:
#         print(i, e)

('/home/ubuntu/Yoga_82_Images/Akarna_Dhanurasana/84.jpg', 'Akarna_Dhanurasana')
('/home/ubuntu/Yoga_82_Images/Akarna_Dhanurasana/34.jpg', 'Akarna_Dhanurasana')


4409

In [16]:
print(len(train_set))
print(len(train_set.all_classes))
# print(train_set.all_classes)
# print(train_set.class_to_idx)
print(train_set.data[0])


17486
82
('/home/ubuntu/Yoga_82_Images/Akarna_Dhanurasana/0_15.jpg', 'Akarna_Dhanurasana')


In [17]:
data_loader = DataLoader(train_set, batch_size=24, shuffle=True, pin_memory=True)

Train the model

In [18]:
# Freeze the weights for densepose layers
for name, parameter in my_model.backbone.named_parameters():
    parameter.requires_grad = False

In [19]:
import torch.optim as optim
from torch.nn import CrossEntropyLoss

# Documentation: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
criterion = CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, my_model.parameters()), lr=0.01)

In [20]:
print(len(list(my_model.parameters())))
print(len(list(filter(lambda p: p.requires_grad, model.parameters()))))

75
54


In [21]:
# for k, v in my_model.named_parameters():
#     print(k, v.sum().item(), v.requires_grad)

In [None]:
# Tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

my_model.train()
my_model.backbone.eval()

for epoch in range(4):
    running_loss = 0.0
    for i, data in enumerate(data_loader, 0):
        x, y = data
        x = x.cuda()
        y = y.cuda()
        # print(x.size(), y.size())
        optimizer.zero_grad()
        out = my_model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        
        # print statistics
        running_loss += loss.item()
        if True: #i % 10 == 9:    # print every 10 batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0

print('Finished Training')
        
    

[1,     1] loss: 0.441
[1,     2] loss: 0.440
[1,     3] loss: 0.441
[1,     4] loss: 0.440


  "Palette images with Transparency expressed in bytes should be "


[1,     5] loss: 0.440
[1,     6] loss: 0.441
[1,     7] loss: 0.440
[1,     8] loss: 0.440
[1,     9] loss: 0.439
[1,    10] loss: 0.440
[1,    11] loss: 0.441
[1,    12] loss: 0.439
[1,    13] loss: 0.440
[1,    14] loss: 0.440
[1,    15] loss: 0.439
[1,    16] loss: 0.440
[1,    17] loss: 0.440
[1,    18] loss: 0.431
[1,    19] loss: 0.439
[1,    20] loss: 0.438
[1,    21] loss: 0.434
[1,    22] loss: 0.437
[1,    23] loss: 0.433
[1,    24] loss: 0.440
[1,    25] loss: 0.440
[1,    26] loss: 0.439
[1,    27] loss: 0.437
[1,    28] loss: 0.441
[1,    29] loss: 0.438
[1,    30] loss: 0.437
[1,    31] loss: 0.439
[1,    32] loss: 0.432
[1,    33] loss: 0.428
[1,    34] loss: 0.433
[1,    35] loss: 0.442
[1,    36] loss: 0.433
[1,    37] loss: 0.434
[1,    38] loss: 0.437
[1,    39] loss: 0.441
[1,    40] loss: 0.436
[1,    41] loss: 0.432
[1,    42] loss: 0.438
[1,    43] loss: 0.440
[1,    44] loss: 0.433
[1,    45] loss: 0.430




[1,    46] loss: 0.439
[1,    47] loss: 0.437
[1,    48] loss: 0.421
[1,    49] loss: 0.433
[1,    50] loss: 0.437
[1,    51] loss: 0.429
[1,    52] loss: 0.430
[1,    53] loss: 0.434
[1,    54] loss: 0.438
[1,    55] loss: 0.432
[1,    56] loss: 0.432
[1,    57] loss: 0.425
[1,    58] loss: 0.436
[1,    59] loss: 0.433
[1,    60] loss: 0.437
[1,    61] loss: 0.434
[1,    62] loss: 0.434
[1,    63] loss: 0.434
[1,    64] loss: 0.439
[1,    65] loss: 0.434
[1,    66] loss: 0.441
[1,    67] loss: 0.438
[1,    68] loss: 0.436
[1,    69] loss: 0.437
[1,    70] loss: 0.437
[1,    71] loss: 0.436
[1,    72] loss: 0.438
[1,    73] loss: 0.441
[1,    74] loss: 0.435
[1,    75] loss: 0.436
[1,    76] loss: 0.434
[1,    77] loss: 0.430
[1,    78] loss: 0.433
[1,    79] loss: 0.437
[1,    80] loss: 0.435
[1,    81] loss: 0.433
[1,    82] loss: 0.435
[1,    83] loss: 0.437
[1,    84] loss: 0.431
[1,    85] loss: 0.435
[1,    86] loss: 0.431
[1,    87] loss: 0.433
[1,    88] loss: 0.440
[1,    89] 

[1,   403] loss: 0.427
[1,   404] loss: 0.430
[1,   405] loss: 0.410
[1,   406] loss: 0.432
[1,   407] loss: 0.430
[1,   408] loss: 0.425
[1,   409] loss: 0.427
[1,   410] loss: 0.430
[1,   411] loss: 0.434
[1,   412] loss: 0.426
[1,   413] loss: 0.424
[1,   414] loss: 0.411
[1,   415] loss: 0.430
[1,   416] loss: 0.435
[1,   417] loss: 0.427
[1,   418] loss: 0.419
[1,   419] loss: 0.427
[1,   420] loss: 0.434
[1,   421] loss: 0.430
[1,   422] loss: 0.427
[1,   423] loss: 0.431
[1,   424] loss: 0.422
[1,   425] loss: 0.439
[1,   426] loss: 0.425
[1,   427] loss: 0.426
[1,   428] loss: 0.428
[1,   429] loss: 0.423
[1,   430] loss: 0.428
[1,   431] loss: 0.423
[1,   432] loss: 0.428
[1,   433] loss: 0.422
[1,   434] loss: 0.433
[1,   435] loss: 0.431
[1,   436] loss: 0.434
[1,   437] loss: 0.418
[1,   438] loss: 0.429
[1,   439] loss: 0.417
[1,   440] loss: 0.414
[1,   441] loss: 0.425
[1,   442] loss: 0.430
[1,   443] loss: 0.430
[1,   444] loss: 0.423
[1,   445] loss: 0.436
[1,   446] 

[2,    31] loss: 0.423
[2,    32] loss: 0.429
[2,    33] loss: 0.426
[2,    34] loss: 0.404
[2,    35] loss: 0.427
[2,    36] loss: 0.409
[2,    37] loss: 0.430
[2,    38] loss: 0.430
[2,    39] loss: 0.423
[2,    40] loss: 0.420
[2,    41] loss: 0.414
[2,    42] loss: 0.427
[2,    43] loss: 0.426
[2,    44] loss: 0.426
[2,    45] loss: 0.415
[2,    46] loss: 0.424
[2,    47] loss: 0.436
[2,    48] loss: 0.434
[2,    49] loss: 0.432
[2,    50] loss: 0.426
[2,    51] loss: 0.424
[2,    52] loss: 0.416
[2,    53] loss: 0.416
[2,    54] loss: 0.415
[2,    55] loss: 0.442
[2,    56] loss: 0.419
[2,    57] loss: 0.406
[2,    58] loss: 0.423
[2,    59] loss: 0.431
[2,    60] loss: 0.419
[2,    61] loss: 0.414
[2,    62] loss: 0.422
[2,    63] loss: 0.423
[2,    64] loss: 0.409
[2,    65] loss: 0.418
[2,    66] loss: 0.430
[2,    67] loss: 0.420
[2,    68] loss: 0.416
[2,    69] loss: 0.430
[2,    70] loss: 0.418
[2,    71] loss: 0.425
[2,    72] loss: 0.434
[2,    73] loss: 0.430
[2,    74] 

[2,   388] loss: 0.430
[2,   389] loss: 0.427
[2,   390] loss: 0.433
[2,   391] loss: 0.422
[2,   392] loss: 0.434
[2,   393] loss: 0.418
[2,   394] loss: 0.425
[2,   395] loss: 0.426
[2,   396] loss: 0.414
[2,   397] loss: 0.423
[2,   398] loss: 0.425
[2,   399] loss: 0.417
[2,   400] loss: 0.423
[2,   401] loss: 0.428
[2,   402] loss: 0.422
[2,   403] loss: 0.424
[2,   404] loss: 0.425
[2,   405] loss: 0.426
[2,   406] loss: 0.432
[2,   407] loss: 0.421
[2,   408] loss: 0.408
[2,   409] loss: 0.422
[2,   410] loss: 0.416
[2,   411] loss: 0.434
[2,   412] loss: 0.425
[2,   413] loss: 0.419
[2,   414] loss: 0.429
[2,   415] loss: 0.423
[2,   416] loss: 0.430
[2,   417] loss: 0.428
[2,   418] loss: 0.420
[2,   419] loss: 0.428
[2,   420] loss: 0.418
[2,   421] loss: 0.408
[2,   422] loss: 0.421
[2,   423] loss: 0.431
[2,   424] loss: 0.422
[2,   425] loss: 0.408
[2,   426] loss: 0.421
[2,   427] loss: 0.424
[2,   428] loss: 0.434
[2,   429] loss: 0.418
[2,   430] loss: 0.417
[2,   431] 

[3,    16] loss: 0.422
[3,    17] loss: 0.420
[3,    18] loss: 0.429
[3,    19] loss: 0.418
[3,    20] loss: 0.407
[3,    21] loss: 0.404
[3,    22] loss: 0.401
[3,    23] loss: 0.419
[3,    24] loss: 0.427
[3,    25] loss: 0.426
[3,    26] loss: 0.432
[3,    27] loss: 0.409
[3,    28] loss: 0.416
[3,    29] loss: 0.401
[3,    30] loss: 0.414
[3,    31] loss: 0.426
[3,    32] loss: 0.425
[3,    33] loss: 0.429
[3,    34] loss: 0.410
[3,    35] loss: 0.418
[3,    36] loss: 0.434
[3,    37] loss: 0.423
[3,    38] loss: 0.421
[3,    39] loss: 0.438
[3,    40] loss: 0.411
[3,    41] loss: 0.406
[3,    42] loss: 0.404
[3,    43] loss: 0.420
[3,    44] loss: 0.426
[3,    45] loss: 0.420
[3,    46] loss: 0.413
[3,    47] loss: 0.419
[3,    48] loss: 0.416
[3,    49] loss: 0.434
[3,    50] loss: 0.432
[3,    51] loss: 0.428
[3,    52] loss: 0.424
[3,    53] loss: 0.422
[3,    54] loss: 0.419
[3,    55] loss: 0.423
[3,    56] loss: 0.422
[3,    57] loss: 0.410
[3,    58] loss: 0.426
[3,    59] 

[3,   373] loss: 0.415
[3,   374] loss: 0.418
[3,   375] loss: 0.424
[3,   376] loss: 0.400
[3,   377] loss: 0.425
[3,   378] loss: 0.418
[3,   379] loss: 0.421
[3,   380] loss: 0.411
[3,   381] loss: 0.425
[3,   382] loss: 0.425
[3,   383] loss: 0.415
[3,   384] loss: 0.426
[3,   385] loss: 0.418
[3,   386] loss: 0.414
[3,   387] loss: 0.403
[3,   388] loss: 0.405
[3,   389] loss: 0.418
[3,   390] loss: 0.414
[3,   391] loss: 0.414
[3,   392] loss: 0.419
[3,   393] loss: 0.418
[3,   394] loss: 0.422
[3,   395] loss: 0.409
[3,   396] loss: 0.427
[3,   397] loss: 0.426
[3,   398] loss: 0.413
[3,   399] loss: 0.403
[3,   400] loss: 0.401
[3,   401] loss: 0.418
[3,   402] loss: 0.432
[3,   403] loss: 0.413
[3,   404] loss: 0.413
[3,   405] loss: 0.410
[3,   406] loss: 0.434
[3,   407] loss: 0.422
[3,   408] loss: 0.409
[3,   409] loss: 0.421
[3,   410] loss: 0.427
[3,   411] loss: 0.421
[3,   412] loss: 0.428
[3,   413] loss: 0.422
[3,   414] loss: 0.417
[3,   415] loss: 0.424
[3,   416] 

[4,     1] loss: 0.402
[4,     2] loss: 0.427
[4,     3] loss: 0.419
[4,     4] loss: 0.428
[4,     5] loss: 0.413
[4,     6] loss: 0.413
[4,     7] loss: 0.424
[4,     8] loss: 0.394
[4,     9] loss: 0.394
[4,    10] loss: 0.407
[4,    11] loss: 0.422
[4,    12] loss: 0.423
[4,    13] loss: 0.421
[4,    14] loss: 0.407
[4,    15] loss: 0.418
[4,    16] loss: 0.433
[4,    17] loss: 0.430
[4,    18] loss: 0.417
[4,    19] loss: 0.416
[4,    20] loss: 0.415
[4,    21] loss: 0.421
[4,    22] loss: 0.406
[4,    23] loss: 0.426
[4,    24] loss: 0.416
[4,    25] loss: 0.431
[4,    26] loss: 0.409
[4,    27] loss: 0.426
[4,    28] loss: 0.421
[4,    29] loss: 0.410
[4,    30] loss: 0.426
[4,    31] loss: 0.410
[4,    32] loss: 0.426
[4,    33] loss: 0.421
[4,    34] loss: 0.417
[4,    35] loss: 0.412
[4,    36] loss: 0.401
[4,    37] loss: 0.428
[4,    38] loss: 0.426
[4,    39] loss: 0.411
[4,    40] loss: 0.408
[4,    41] loss: 0.408
[4,    42] loss: 0.417
[4,    43] loss: 0.414
[4,    44] 

[4,   358] loss: 0.430
[4,   359] loss: 0.418
[4,   360] loss: 0.431
[4,   361] loss: 0.417
[4,   362] loss: 0.424
[4,   363] loss: 0.411
[4,   364] loss: 0.416
[4,   365] loss: 0.407
[4,   366] loss: 0.432
[4,   367] loss: 0.414
[4,   368] loss: 0.415
[4,   369] loss: 0.407
[4,   370] loss: 0.409
[4,   371] loss: 0.418
[4,   372] loss: 0.413
[4,   373] loss: 0.413
[4,   374] loss: 0.426
[4,   375] loss: 0.426
[4,   376] loss: 0.423
[4,   377] loss: 0.428
[4,   378] loss: 0.411
[4,   379] loss: 0.426
[4,   380] loss: 0.413
[4,   381] loss: 0.421
[4,   382] loss: 0.401
[4,   383] loss: 0.418
[4,   384] loss: 0.405
[4,   385] loss: 0.421
[4,   386] loss: 0.433
[4,   387] loss: 0.410
[4,   388] loss: 0.416
[4,   389] loss: 0.413
[4,   390] loss: 0.414
[4,   391] loss: 0.419
[4,   392] loss: 0.418
[4,   393] loss: 0.414
[4,   394] loss: 0.410
[4,   395] loss: 0.420
[4,   396] loss: 0.414
[4,   397] loss: 0.424
[4,   398] loss: 0.394
[4,   399] loss: 0.417
[4,   400] loss: 0.425
[4,   401] 

[4,   715] loss: 0.404
[4,   716] loss: 0.432
[4,   717] loss: 0.421
[4,   718] loss: 0.413
[4,   719] loss: 0.415
[4,   720] loss: 0.424
[4,   721] loss: 0.417
[4,   722] loss: 0.433
[4,   723] loss: 0.414
[4,   724] loss: 0.414
[4,   725] loss: 0.430
[4,   726] loss: 0.422
[4,   727] loss: 0.418
[4,   728] loss: 0.395
[4,   729] loss: 0.434
[5,     1] loss: 0.416
[5,     2] loss: 0.426
[5,     3] loss: 0.430
[5,     4] loss: 0.390
[5,     5] loss: 0.424
[5,     6] loss: 0.420
[5,     7] loss: 0.409
[5,     8] loss: 0.422
[5,     9] loss: 0.422
[5,    10] loss: 0.409
[5,    11] loss: 0.430
[5,    12] loss: 0.396
[5,    13] loss: 0.418
[5,    14] loss: 0.427
[5,    15] loss: 0.409
[5,    16] loss: 0.409
[5,    17] loss: 0.408
[5,    18] loss: 0.412
[5,    19] loss: 0.427
[5,    20] loss: 0.409
[5,    21] loss: 0.402
[5,    22] loss: 0.397
[5,    23] loss: 0.412
[5,    24] loss: 0.426
[5,    25] loss: 0.412
[5,    26] loss: 0.418
[5,    27] loss: 0.414
[5,    28] loss: 0.410
[5,    29] 

[5,   343] loss: 0.410
[5,   344] loss: 0.414
[5,   345] loss: 0.421
[5,   346] loss: 0.414
[5,   347] loss: 0.413
[5,   348] loss: 0.426
[5,   349] loss: 0.422
[5,   350] loss: 0.403
[5,   351] loss: 0.415
[5,   352] loss: 0.424
[5,   353] loss: 0.414
[5,   354] loss: 0.430
[5,   355] loss: 0.422
[5,   356] loss: 0.405
[5,   357] loss: 0.418
[5,   358] loss: 0.417
[5,   359] loss: 0.423
[5,   360] loss: 0.418
[5,   361] loss: 0.417
[5,   362] loss: 0.413
[5,   363] loss: 0.425
[5,   364] loss: 0.422
[5,   365] loss: 0.393
[5,   366] loss: 0.407
[5,   367] loss: 0.417
[5,   368] loss: 0.424
[5,   369] loss: 0.413
[5,   370] loss: 0.422
[5,   371] loss: 0.406
[5,   372] loss: 0.419
[5,   373] loss: 0.397
[5,   374] loss: 0.414
[5,   375] loss: 0.418
[5,   376] loss: 0.401
[5,   377] loss: 0.413
[5,   378] loss: 0.426
[5,   379] loss: 0.430
[5,   380] loss: 0.416
[5,   381] loss: 0.421
[5,   382] loss: 0.424
[5,   383] loss: 0.399
[5,   384] loss: 0.427
[5,   385] loss: 0.412
[5,   386] 

[5,   700] loss: 0.425
[5,   701] loss: 0.418
[5,   702] loss: 0.428
[5,   703] loss: 0.423
[5,   704] loss: 0.422
[5,   705] loss: 0.419
[5,   706] loss: 0.430
[5,   707] loss: 0.423
[5,   708] loss: 0.411
[5,   709] loss: 0.420
[5,   710] loss: 0.425
[5,   711] loss: 0.413
[5,   712] loss: 0.418
[5,   713] loss: 0.417
[5,   714] loss: 0.401
[5,   715] loss: 0.422
[5,   716] loss: 0.418
[5,   717] loss: 0.407
[5,   718] loss: 0.426
[5,   719] loss: 0.420
[5,   720] loss: 0.404
[5,   721] loss: 0.430
[5,   722] loss: 0.409
[5,   723] loss: 0.411
[5,   724] loss: 0.418
[5,   725] loss: 0.413
[5,   726] loss: 0.417
[5,   727] loss: 0.398
[5,   728] loss: 0.421
[5,   729] loss: 0.421
[6,     1] loss: 0.413
[6,     2] loss: 0.422
[6,     3] loss: 0.401
[6,     4] loss: 0.416
[6,     5] loss: 0.417
[6,     6] loss: 0.399
[6,     7] loss: 0.399
[6,     8] loss: 0.402
[6,     9] loss: 0.410
[6,    10] loss: 0.423
[6,    11] loss: 0.394
[6,    12] loss: 0.414
[6,    13] loss: 0.409
[6,    14] 

[6,   328] loss: 0.414
[6,   329] loss: 0.410
[6,   330] loss: 0.417
[6,   331] loss: 0.412
[6,   332] loss: 0.426
[6,   333] loss: 0.418
[6,   334] loss: 0.430
[6,   335] loss: 0.420
[6,   336] loss: 0.409
[6,   337] loss: 0.422
[6,   338] loss: 0.414
[6,   339] loss: 0.420
[6,   340] loss: 0.405
[6,   341] loss: 0.410
[6,   342] loss: 0.426
[6,   343] loss: 0.413
[6,   344] loss: 0.415
[6,   345] loss: 0.405
[6,   346] loss: 0.409
[6,   347] loss: 0.415
[6,   348] loss: 0.414
[6,   349] loss: 0.409
[6,   350] loss: 0.412
[6,   351] loss: 0.421
[6,   352] loss: 0.426
[6,   353] loss: 0.402
[6,   354] loss: 0.407
[6,   355] loss: 0.392
[6,   356] loss: 0.416
[6,   357] loss: 0.414
[6,   358] loss: 0.406
[6,   359] loss: 0.409
[6,   360] loss: 0.405
[6,   361] loss: 0.417
[6,   362] loss: 0.419
[6,   363] loss: 0.410
[6,   364] loss: 0.422
[6,   365] loss: 0.409
[6,   366] loss: 0.423
[6,   367] loss: 0.424
[6,   368] loss: 0.418
[6,   369] loss: 0.425
[6,   370] loss: 0.418
[6,   371] 

[6,   685] loss: 0.424
[6,   686] loss: 0.419
[6,   687] loss: 0.419
[6,   688] loss: 0.421
[6,   689] loss: 0.406
[6,   690] loss: 0.423
[6,   691] loss: 0.414
[6,   692] loss: 0.408
[6,   693] loss: 0.415
[6,   694] loss: 0.402
[6,   695] loss: 0.414
[6,   696] loss: 0.407
[6,   697] loss: 0.414
[6,   698] loss: 0.406
[6,   699] loss: 0.422
[6,   700] loss: 0.407
[6,   701] loss: 0.424
[6,   702] loss: 0.409
[6,   703] loss: 0.414
[6,   704] loss: 0.414
[6,   705] loss: 0.413
[6,   706] loss: 0.438
[6,   707] loss: 0.434
[6,   708] loss: 0.410
[6,   709] loss: 0.414
[6,   710] loss: 0.427
[6,   711] loss: 0.398
[6,   712] loss: 0.389
[6,   713] loss: 0.422
[6,   714] loss: 0.431
[6,   715] loss: 0.410
[6,   716] loss: 0.411
[6,   717] loss: 0.424
[6,   718] loss: 0.421
[6,   719] loss: 0.409
[6,   720] loss: 0.409
[6,   721] loss: 0.413
[6,   722] loss: 0.413
[6,   723] loss: 0.421
[6,   724] loss: 0.408
[6,   725] loss: 0.420
[6,   726] loss: 0.417
[6,   727] loss: 0.409
[6,   728] 

[7,   313] loss: 0.409
[7,   314] loss: 0.404
[7,   315] loss: 0.433
[7,   316] loss: 0.412
[7,   317] loss: 0.422
[7,   318] loss: 0.414
[7,   319] loss: 0.420
[7,   320] loss: 0.409
[7,   321] loss: 0.427
[7,   322] loss: 0.401
[7,   323] loss: 0.417
[7,   324] loss: 0.415
[7,   325] loss: 0.405
[7,   326] loss: 0.421
[7,   327] loss: 0.400
[7,   328] loss: 0.419
[7,   329] loss: 0.395
[7,   330] loss: 0.426
[7,   331] loss: 0.428
[7,   332] loss: 0.418
[7,   333] loss: 0.430
[7,   334] loss: 0.415
[7,   335] loss: 0.427
[7,   336] loss: 0.414
[7,   337] loss: 0.423
[7,   338] loss: 0.416
[7,   339] loss: 0.409
[7,   340] loss: 0.422
[7,   341] loss: 0.397
[7,   342] loss: 0.420
[7,   343] loss: 0.417
[7,   344] loss: 0.411
[7,   345] loss: 0.425
[7,   346] loss: 0.407
[7,   347] loss: 0.418
[7,   348] loss: 0.410
[7,   349] loss: 0.398
[7,   350] loss: 0.418
[7,   351] loss: 0.421
[7,   352] loss: 0.412
[7,   353] loss: 0.418
[7,   354] loss: 0.406
[7,   355] loss: 0.387
[7,   356] 

In [None]:
# for k, v in my_model.named_parameters():
#     print(k, v.sum().item(), v.requires_grad)

In [None]:
# PATH = './yoga_net.pth'
# torch.save(my_model.state_dict(), PATH)

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
import pickle


In [None]:
# Run on test set
test_loader = DataLoader(test_set, batch_size=16, shuffle=False, pin_memory=True)

correct = 0
total = 0

test_results = []

# my_model.eval()
#pickle.dump(my_model, open("densepose_model.sav", 'wb'))
#my_model = pickle.load(open("densepose_model.sav", 'rb'))

with torch.no_grad():
    for data in test_loader:
        x, y = data
        x = x.cuda()
        y = y.cuda()
        out = my_model(x)
        _, predicted = torch.max(out.data, 1)
        total += y.size(0)
        correct += (predicted == y).sum()
        for p in predicted.detach().cpu().numpy():
            test_results.append(p)

print(f'Accuracy of the network on the {total} test images: {100 * correct / total}%')

In [None]:
# Confusion matrix
nb_classes = 10
all_predicted = []
all_y = []
with torch.no_grad():
    for data in test_loader:
        x, y = data
        x = x.cuda()
        y = y.cuda()
        out = my_model(x)
        _, predicted = torch.max(out.data, 1)
        for p in predicted.detach().cpu().numpy():
            all_predicted.append(p)
        for g in y.detach().cpu().numpy():
            all_y.append(g)

confusion_matrix(all_predicted, all_y)


In [None]:
for i in range(len(test_set)):
    print(test_set.data[i], test_set.all_classes[test_results[i]])