In [3]:
# [STAR] All the Imports

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

from pathlib import Path
import ast
import pandas as pd
from PIL import Image

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torchvision

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler

from matplotlib import pyplot as plt
import re
import cv2

import random
import glob

import csv
from scipy import ndimage, misc
from tqdm import tqdm

In [4]:
# [STAR] Attribute and Category Model

class Identity(nn.Module):
    """Pass-through module whose ``forward`` hands its input back untouched.

    Installed in place of a backbone's final ``fc`` layer so the backbone
    returns the features that would otherwise have been fed to ``fc``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself (same object, no copy)."""
        passthrough = x
        return passthrough

class Averager:
    """Running arithmetic mean of every value passed to :meth:`send`."""

    def __init__(self):
        # Start from the same empty state that reset() restores.
        self.reset()

    def send(self, value):
        """Accumulate ``value`` and count one more sample."""
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        """Mean of the values sent since the last reset; ``0`` before any send."""
        if self.iterations != 0:
            return 1.0 * self.current_total / self.iterations
        return 0

    def reset(self):
        """Discard all accumulated values."""
        self.current_total, self.iterations = 0.0, 0.0

class MyAttrCateModel(nn.Module):
    """ResNet-18 backbone with two linear heads: attribute and category logits.

    Args:
        num_attrs: output size of the attribute head (default 26).
        num_cates: output size of the category head (default 50).
        pretrained: forwarded to ``torchvision.models.resnet18``; pass ``False``
            to build the network without downloading ImageNet weights
            (e.g. right before loading a fine-tuned checkpoint).
    """

    def __init__(self, num_attrs=26, num_cates=50, pretrained=True):
        super().__init__()
        self.model = models.resnet18(pretrained=pretrained)
        # Read the feature width off the classifier before discarding it
        # (512 for ResNet-18) instead of hard-coding it in the heads.
        feat_dim = self.model.fc.in_features
        # Replace the ImageNet classifier so the backbone returns the features
        # that used to feed it.
        self.model.fc = Identity()

        # Attribute names kept as attr_layer / cate_layer so saved state dicts
        # remain loadable.
        self.attr_layer = nn.Linear(feat_dim, num_attrs)
        self.cate_layer = nn.Linear(feat_dim, num_cates)

    def forward(self, x):
        """Return ``(attr_logits, cate_logits)`` for a batch of images ``x``.

        Both outputs are raw logits; no sigmoid/softmax is applied here.
        """
        features = self.model(x)
        return self.attr_layer(features), self.cate_layer(features)

class MyAttrCateModel50(nn.Module):
    """ResNet-50 backbone with two linear heads: attribute and category logits.

    Args:
        num_attrs: output size of the attribute head (default 26).
        num_cates: output size of the category head (default 50).
        pretrained: forwarded to ``torchvision.models.resnet50``; pass ``False``
            to build the network without downloading ImageNet weights
            (e.g. right before loading a fine-tuned checkpoint).
    """

    def __init__(self, num_attrs=26, num_cates=50, pretrained=True):
        super().__init__()
        self.model = models.resnet50(pretrained=pretrained)
        # Read the feature width off the classifier before discarding it
        # (2048 for ResNet-50) instead of hard-coding it in the heads.
        feat_dim = self.model.fc.in_features
        # Replace the ImageNet classifier so the backbone returns the features
        # that used to feed it.
        self.model.fc = Identity()

        # Attribute names kept as attr_layer / cate_layer so saved state dicts
        # (e.g. fashion_cate_attr_resnet50_single_linear.pth) remain loadable.
        self.attr_layer = nn.Linear(feat_dim, num_attrs)
        self.cate_layer = nn.Linear(feat_dim, num_cates)

    def forward(self, x):
        """Return ``(attr_logits, cate_logits)`` for a batch of images ``x``.

        Both outputs are raw logits; no sigmoid/softmax is applied here.
        """
        features = self.model(x)
        return self.attr_layer(features), self.cate_layer(features)
#model  = MyAttrCateModel()
# x      = torch.randn(1, 3, 224, 224)
# output = model(x)
# print(output[0].shape, output[1].shape)

#print(model)
#model_ft = model_ft.to(device)

In [5]:
# [CLEAN CODE]

from __future__ import division
import os

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data.dataset import Dataset


class AttrDataset(Dataset):
    """Labelled training dataset: image plus attribute and category targets.

    Args:
        img_path: directory that the names listed in ``img_file`` are relative to.
        img_file: text file with one image name per line.
        label_file: numeric text matrix read with ``np.loadtxt`` (float32);
            row ``idx`` holds the attribute labels of image ``idx``.
        cate_file: text file with one integer category id per line; ids are
            shifted down by one in ``get_basic_item`` (presumably 1-based in
            the file -- confirm against the annotation format).
        bbox_file, landmark_file, idx2id: accepted for interface compatibility;
            not read by this class.
        img_size: ``(w, h)`` bound passed to ``Image.thumbnail``; ``img_size[0]``
            is also the ``RandomResizedCrop`` output size.
    """
    CLASSES = None

    def __init__(self,
                 img_path,
                 img_file,
                 label_file,
                 cate_file,
                 bbox_file,
                 landmark_file,
                 img_size,
                 idx2id=None):
        self.img_path = img_path

        # Standard ImageNet normalization constants.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        # Image names, one per line; the file is closed once read.
        with open(img_file, 'r') as fp:
            self.img_list = [x.strip() for x in fp]

        # Attribute labels, one row per image.
        self.labels = np.loadtxt(label_file, dtype=np.float32)

        # Category ids kept as strings; converted to ints in get_basic_item.
        with open(cate_file) as fp:
            self.categories = [line.strip('\n') for line in fp]

        self.img_size = img_size

    def get_basic_item(self, idx):
        """Load sample ``idx`` as ``{'img': tensor, 'attr': tensor, 'cate': tensor}``.

        ``attr`` is the float32 label row for ``idx``; ``cate`` is a 1-element
        LongTensor holding the category id minus one.
        """
        path = os.path.join(self.img_path, self.img_list[idx])
        # Context manager closes the file; convert() returns an independent image.
        with Image.open(path) as src:
            img = src.convert('RGB')

        # Very Important: shrink (in place, keeping aspect ratio) to fit within
        # img_size before the random crop. LANCZOS is the filter that the
        # removed-in-Pillow-10 name ANTIALIAS aliased.
        img.thumbnail(self.img_size, Image.LANCZOS)
        img = self.transform(img)

        label = torch.from_numpy(self.labels[idx])
        cate = torch.LongTensor([int(self.categories[idx]) - 1])

        return {'img': img, 'attr': label, 'cate': cate}

    def __getitem__(self, idx):
        return self.get_basic_item(idx)

    def __len__(self):
        return len(self.img_list)


class ValidAttrDataset(Dataset):
    """Unlabelled image dataset used for inference.

    Each item is ``{'img': tensor, 'imgpath': str}``, or ``None`` when the image
    file cannot be opened/decoded (a warning is emitted). Callers must handle
    ``None``; the default DataLoader collate cannot batch it.

    ``label_file``, ``cate_file``, ``bbox_file``, ``landmark_file`` and
    ``idx2id`` are accepted so the constructor mirrors ``AttrDataset``, but
    they are not read.
    """
    CLASSES = None

    def __init__(self,
                 img_path,
                 img_file,
                 label_file,
                 cate_file,
                 bbox_file,
                 landmark_file,
                 img_size,
                 idx2id=None):
        self.img_path = img_path

        # Standard ImageNet normalization constants.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # NOTE(review): these are random augmentations, so repeated inference
        # on the same image can give different predictions; a deterministic
        # Resize/CenterCrop pipeline is the usual choice for evaluation.
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        # One image name per line. Blank lines are skipped: the trailing
        # newline used to yield an empty name that resolved to img_path itself.
        with open(img_file, 'r') as fp:
            self.img_list = [line.rstrip('\n') for line in fp if line.strip()]

        self.img_size = img_size

    def get_basic_item(self, idx):
        """Load image ``idx``; return ``None`` (with a warning) if it is unreadable.

        An out-of-range ``idx`` raises ``IndexError`` (no longer swallowed), so
        sequence-style iteration over the dataset terminates.
        """
        imgpath = os.path.join(self.img_path, self.img_list[idx])
        try:
            with Image.open(imgpath) as src:
                img = src.convert('RGB')
        except OSError as exc:
            # Covers missing files, directories and undecodable images
            # (PIL's UnidentifiedImageError is an OSError subclass).
            import warnings
            warnings.warn('Skipping unreadable image %r: %s' % (imgpath, exc))
            return None

        # Shrink (in place, keeping aspect ratio) to fit within img_size before
        # the crop. LANCZOS is what the removed-in-Pillow-10 ANTIALIAS aliased.
        img.thumbnail(self.img_size, Image.LANCZOS)
        img = self.transform(img)

        return {'img': img, 'imgpath': imgpath}

    def __getitem__(self, idx):
        return self.get_basic_item(idx)

    def __len__(self):
        return len(self.img_list)

# Data configuration -- change these as per your need.
# NOTE(review): absolute, machine-specific paths; consider a configurable root.
# To use the MMFashion training split instead, point img_path at
# ".../CategoryandAttributePredictionBenchmark/Img/" and use train.txt,
# train_attr.txt and train_cate.txt from ANNO_DIR.
ANNO_DIR   = "/media/yu-hao/WindowsData/MMFASHION-DATASET/CategoryandAttributePredictionBenchmark/Anno_fine/"

# Images to run inference on, and the list of their names relative to img_path.
img_path   = "/home/yu-hao/Downloads/30 Styles-20210216T225758Z-001/30 Styles/"
img_file   = ANNO_DIR + "newval.txt"

# Validation-split annotation files; passed through to ValidAttrDataset,
# whose constructor does not read them.
label_file = ANNO_DIR + "val_attr.txt"
cate_file  = ANNO_DIR + "val_cate.txt"

img_size   = [224, 224]   # [256, 256] was also tried

landmark_file = None
bbox_file     = None

d3 = ValidAttrDataset(img_path, img_file, label_file, cate_file, bbox_file, landmark_file, img_size, idx2id=None)

from torch.utils.data import DataLoader


def build_dataloader(dataset, batch_size, shuffle, num_workers=0, pin_memory=False):
    """Wrap ``dataset`` in a :class:`torch.utils.data.DataLoader`.

    Args:
        dataset: the dataset to draw samples from.
        batch_size: number of samples per batch.
        shuffle: whether to reshuffle the sample order every epoch.
        num_workers: loader worker processes; ``0`` (default) loads in the
            main process.
        pin_memory: copy batches into page-locked memory before returning
            them; off by default.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory)


# Build the ResNet-50 attribute/category model on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyAttrCateModel50()
model.to(device)

# Adam over the trainable parameters (only relevant for training; the
# single-image inference cell does not step it).
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.Adam(params, lr=0.0001, weight_decay=0.0001)

# NOTE(review): this rebinds `lr_scheduler`, shadowing the
# `torch.optim.lr_scheduler` module imported in the first cell.
lr_scheduler = None


In [6]:
# [STAR] MMFASHION Testing on a single image

ANNO_DIR = "/media/yu-hao/WindowsData/MMFASHION-DATASET/CategoryandAttributePredictionBenchmark/Anno_fine/"


def read_label_names(path, header_lines=2):
    """Return the first whitespace-separated token of each entry line in ``path``.

    The first ``header_lines`` lines are skipped, as are blank lines (such as
    the one left by a trailing newline), so files with or without a final
    newline give the same result.
    """
    with open(path) as fh:
        lines = fh.read().split('\n')[header_lines:]
    return np.array([line.split()[0] for line in lines if line.strip()])


attr_list = read_label_names(ANNO_DIR + "list_attr_cloth.txt")
cate_list = read_label_names(ANNO_DIR + "list_category_cloth.txt")

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('fashion_cate_attr_resnet50_single_linear.pth',
                                 map_location=device))
model.eval()

# randrange excludes len(d3); randint(0, len(d3)) could pick one past the end.
index = random.randrange(len(d3))
t1    = d3[index]
if t1 is None:
    raise RuntimeError("Could not read image at index %d of d3" % index)

# Add a leading batch dimension to the single image tensor.
new_images = t1['img'].unsqueeze(0).to(device)

# Inference only: no autograd graph needed.
with torch.no_grad():
    out1, out2 = model(new_images)

attr_probs = torch.sigmoid(out1).cpu().numpy().flatten()
cate_probs = torch.softmax(out2, dim=1).cpu().numpy().flatten()

# Attributes predicted present: probability at or above 0.5.
attr_index = np.flatnonzero(attr_probs >= 0.5)
cate_index = np.argmax(cate_probs)

print("Predicted:    ", cate_list[cate_index], attr_list[attr_index])
        

Predicted:     Shorts [['solid' 'sleeveless' 'no_dress' 'no_neckline' 'cotton' 'conventional']]
