In [None]:
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision.transforms as T


# Crop saved figures to a tight bounding box around their content.
plt.rcParams.update({"savefig.bbox": "tight"})


def plot(imgs, with_orig=True, row_title=None, *, orig_img=None, **imshow_kwargs):
    """Show images on a grid of axes with all ticks removed.

    Args:
        imgs: either a flat list of images (drawn as a single row) or a list of
            lists, one inner list per row. Each image goes through
            ``np.asarray`` and then ``Axes.imshow``. Rows may differ in length;
            the grid is sized for the longest row and unused cells are hidden.
        with_orig: if true, prepend ``orig_img`` to every row and title the
            top-left axes "Original image".
        row_title: optional sequence of y-axis labels indexed by row; it must
            provide at least one entry per row.
        orig_img: reference image for the first column when ``with_orig`` is
            true. If omitted, a notebook-level global named ``orig_img`` is
            used (this function previously read that global implicitly).
        **imshow_kwargs: forwarded unchanged to every ``Axes.imshow`` call.

    Raises:
        ValueError: if ``imgs`` is empty, or if ``with_orig`` is true and no
            reference image is available.
    """
    if len(imgs) == 0:
        raise ValueError("imgs must contain at least one image")
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]
    if with_orig and orig_img is None:
        # Backward compatibility: fall back to the global the old code relied on.
        orig_img = globals().get('orig_img')
        if orig_img is None:
            raise ValueError(
                "with_orig=True needs a reference image: pass orig_img=... "
                "or define a global `orig_img`")

    num_rows = len(imgs)
    # Size the grid for the longest row so ragged rows never index past the axes.
    num_cols = max(len(row) for row in imgs) + int(bool(with_orig))
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + list(row) if with_orig else list(row)
        for col_idx, ax in enumerate(axs[row_idx]):
            if col_idx >= len(row):
                ax.axis('off')  # padding cell for a row shorter than the longest
                continue
            ax.imshow(np.asarray(row[col_idx]), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

In [None]:
# Skip the clone when the directory already exists so the cell can be re-run.
![ -d FashionDataset ] || git clone https://github.com/szx159753/FashionDataset.git

Cloning into 'FashionDataset'...
remote: Enumerating objects: 14, done.
remote: Counting objects: 100% (14/14), done.
remote: Compressing objects: 100% (10/10), done.
remote: Total 14 (delta 2), reused 14 (delta 2), pack-reused 0
Unpacking objects: 100% (14/14), done.
Checking out files: 100% (6/6), done.


In [None]:
# Copy (not move) so re-running the cell still finds the source file. The
# destination is the working directory instead of the Colab-specific /content/;
# assumes the notebook's cwd is on sys.path (the Jupyter default) so that
# `from focal_loss import FocalLoss` can import it.
!cp FashionDataset/focal_loss.py .

In [None]:
# Extract the data archive quietly (no per-file listing flooding the cell output).
!tar xf FashionDataset/data.tar

In [None]:
from __future__ import print_function, division
import os

from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from collections import defaultdict
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import transforms
import subprocess
from focal_loss import FocalLoss

In [None]:
# Dataset root: MyDataset reads split/*.txt list files and image paths relative to it.
data_dir = 'FashionDataset'
# Presumably the number of classes for each of six label columns in
# split/*_attr.txt -- TODO confirm; not referenced in this part of the notebook.
classes = [7, 3, 3, 4, 6, 3]

In [109]:
class MyDataset(Dataset):
    """Cropped fashion images for one split, listed in ``<root>/split/`` text files.

    For a split name ``dir`` these line-aligned files are read:

    * ``split/<dir>.txt``      -- one image path per line, relative to the root;
    * ``split/<dir>_bbox.txt`` -- whitespace-separated integers per line; the
      first four are used as a PIL crop box ``(left, upper, right, lower)``;
    * ``split/<dir>_attr.txt`` -- only when ``y_label`` is true: whitespace-
      separated integer labels per line.

    ``__getitem__`` returns ``(image_tensor, labels)`` when ``y_label`` is true
    (labels as a float32 tensor) and ``(image_tensor, relative_path)`` otherwise.
    The image tensor is the cropped RGB image resized to 224x224, converted with
    ``ToTensor`` and normalized per channel.
    """

    def __init__(self, dir, y_label=False, augment=None, root=None):
        """Read the split's list files; images themselves are loaded lazily.

        Args:
            dir: split name, e.g. ``'train'``, ``'val'`` or ``'test'``.
            y_label: also load ``<dir>_attr.txt`` and return labels per item.
            augment: apply ``RandomHorizontalFlip``. ``None`` (default) enables
                it only for the ``'train'`` split so evaluation is deterministic.
            root: dataset root directory; defaults to the global ``data_dir``.

        Raises:
            ValueError: if the box or label file does not have exactly one
                entry per image path.
        """
        self.root = data_dir if root is None else root
        if augment is None:
            augment = dir == 'train'

        steps = [transforms.Resize([224, 224])]
        if augment:
            steps.append(transforms.RandomHorizontalFlip())
        steps += [
            transforms.ToTensor(),
            # The widely used ImageNet per-channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(steps)

        self.y_label = bool(y_label)
        self.files = self._read_lines(dir + '.txt')
        self.len = len(self.files)
        self.bbox = [line.split() for line in self._read_lines(dir + '_bbox.txt')]
        self._check_aligned(self.bbox, dir + '_bbox.txt')
        if self.y_label:
            self.labels = [line.split() for line in self._read_lines(dir + '_attr.txt')]
            self._check_aligned(self.labels, dir + '_attr.txt')

    def _read_lines(self, name):
        """Return the non-blank lines of ``<root>/split/<name>``, closing the file."""
        # splitlines() keeps the final entry even without a trailing newline and
        # handles CRLF line endings; blank lines are skipped.
        with open(os.path.join(self.root, 'split', name)) as fh:
            return [line for line in fh.read().splitlines() if line.strip()]

    def _check_aligned(self, rows, name):
        """Raise ValueError unless ``rows`` has exactly one entry per image path."""
        if len(rows) != len(self.files):
            raise ValueError(
                "split/{} has {} entries but the image list has {}".format(
                    name, len(rows), len(self.files)))

    def __getitem__(self, idx):
        path = os.path.join(self.root, self.files[idx])
        box = tuple(int(v) for v in self.bbox[idx][:4])
        with Image.open(path) as img:
            # crop() loads the pixel data (non-lazy since Pillow 3.4), so the result
            # outlives the file handle; RGB guarantees the 3 channels Normalize expects.
            img = img.crop(box).convert('RGB')
        img = self.transform(img)
        if self.y_label:
            labels = np.array([int(v) for v in self.labels[idx]], dtype=np.float32)
            return img, torch.from_numpy(labels)
        return img, self.files[idx]

    def __len__(self):
        return len(self.files)

In [110]:
# One dataset per split; train and val carry attribute labels, test returns file paths.
image_datasets = {
    split: MyDataset(split, has_labels)
    for split, has_labels in [('train', True), ('val', True), ('test', False)]
}
# Shuffle only the training split; val/test keep file order for evaluation.
dataloaders = {
    split: DataLoader(image_datasets[split], batch_size=16,
                      shuffle=(split == 'train'), num_workers=2)
    for split in ('train', 'val', 'test')
}

# Sanity check: write one training crop to img.jpg. Batches are normalized, so
# undo the Normalize used in MyDataset (same mean / std) before ToPILImage;
# otherwise values outside [0, 1] are mangled by the conversion to 8-bit pixels.
NORM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
NORM_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
unloader = transforms.ToPILImage()
sample_images = next(iter(dataloaders['train']))[0]
sample = sample_images[min(6, len(sample_images) - 1)]
img = unloader((sample * NORM_STD + NORM_MEAN).clamp(0, 1))
img.save("img.jpg")