In [None]:
import torch
import torch.utils.data as data
from torch.utils.data import DataLoader
import os
import cv2
import numpy as np

import csv

def _load_imagenet_label_dict(classes_path='imagenet_classes.txt',
                              synsets_path='imagenet_synsets.txt'):
    """Build the ImageNet class-index -> human-readable-name lookup.

    Args:
        classes_path: text file with one synset id per line; the id on line i
            (0-based) is model output class i.
        synsets_path: text file whose lines look like "<synset_id> <name ...>";
            everything after the first whitespace run is kept as the name.

    Returns:
        (synset_dict, label_dict): synset id -> name, and class index -> name.

    Raises:
        KeyError: if a class id (including an empty string from a blank line in
            the middle of classes_path) has no entry in the synset file.
    """
    synset_names = {}
    with open(synsets_path, 'r') as synsets_file:
        for line in synsets_file:
            # Split only on the FIRST whitespace run so multi-word names stay whole
            # (splitting on every space kept only the first word, e.g. "tench,").
            fields = line.split(maxsplit=1)
            if not fields:
                continue  # blank line
            synset_names[fields[0]] = fields[1].strip() if len(fields) > 1 else ''

    with open(classes_path, 'r') as classes_file:
        # rstrip() tolerates trailing blank lines only; a blank line anywhere else
        # would shift every later class index, so it is kept and fails loudly in
        # the lookup below instead of silently mislabeling classes.
        class_ids = [line.strip() for line in classes_file.read().rstrip().splitlines()]

    labels = {index: synset_names[synset_id] for index, synset_id in enumerate(class_ids)}
    return synset_names, labels


synset_dict, label_dict = _load_imagenet_label_dict()

Define the dataset: a `Dataset` that loads each image in a directory with OpenCV

In [None]:
class CelebA_Dataset(data.Dataset):
    """Map-style dataset over the regular files in one directory, decoded with OpenCV.

    Each item is a float32 array shaped (height, width, 3) in RGB channel order,
    min-max scaled per image to the range [0, 1].
    """

    def __init__(self, path):
        """
        Args:
            path: directory containing the images. Subdirectories are skipped, and
                file names are sorted so that indices are stable across runs
                (os.listdir returns entries in arbitrary order).
        """
        image_names = sorted(os.listdir(path))
        self.image_paths = [
            os.path.join(path, name)
            for name in image_names
            if os.path.isfile(os.path.join(path, name))
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports unreadable / non-image files by returning None.
            raise OSError(f"Could not read image file: {img_path}")
        # OpenCV decodes to BGR; the ImageNet mean/std in VGGClass and matplotlib
        # both expect RGB channel order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Per-image min-max scaling to [0, 1].
        # NOTE(review): torchvision's pretrained models are normally fed pixel/255;
        # min-max stretches each image's contrast instead -- confirm this is intended.
        img = cv2.normalize(img, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        return img

Create the dataset instance and the PyTorch `DataLoader`

In [None]:
# Build the dataset over ./random_celeba, then sanity-check it: the number of
# images found, and the array shape of the first decoded image.
celeba_dataset = CelebA_Dataset("./random_celeba")
print(f"Number of images: {len(celeba_dataset)}")
print(f"{celeba_dataset[0].shape}")

In [None]:
# Shuffling loader yielding batches of 4 images, loaded in one worker process.
# NOTE(review): worker processes must be able to re-import the Dataset class; with
# the "spawn" start method (Windows/macOS) a class defined in a notebook may fail
# to load there -- use num_workers=0 if that happens.
torch_celeba = DataLoader(
    celeba_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=1,
)
print(f"Number of batches: {len(torch_celeba)}")

Iterate over the batches and check their shapes

In [None]:
# Walk the loader once and report each batch's shape.
# The loop variable is `batch`, not `data`: `data` is the `torch.utils.data` alias
# imported at the top, which CelebA_Dataset's base class refers to -- shadowing it
# breaks re-running the dataset cell.
for index, batch in enumerate(torch_celeba):
    print("Batch number:", index, "\t Shape of data:", batch.size())
    # On a GPU machine the batch can be moved with batch.cuda() (or batch.to(device)).

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img, mean=0.0, std=1.0):
    """Display a (channels, height, width) image tensor, e.g. from make_grid, with matplotlib.

    Args:
        img: tensor shaped (channels, height, width); 3 channels are shown as RGB.
        mean, std: undo an earlier ``(x - mean) / std`` normalization before display.
            The defaults are the identity, which suits images already in [0, 1]
            (what this notebook's dataset produces). Pass ``mean=0.5, std=0.5`` for
            images that were normalized to [-1, 1].
    """
    img = img * std + mean  # unnormalize (identity by default)
    # detach()/cpu() so tensors that track gradients or live on a GPU convert to numpy.
    npimg = img.detach().cpu().numpy()
    # matplotlib expects (height, width, channels); float RGB must lie in [0, 1].
    plt.imshow(np.clip(np.transpose(npimg, (1, 2, 0)), 0, 1))
    plt.show()

In [None]:
import torch
import torchvision

class VGGClass(torch.nn.Module):
    """ImageNet classifier: pretrained torchvision VGG16 behind its input preprocessing.

    ``forward`` takes a float batch shaped (batch, 3, height, width), RGB, with values
    presumably in [0, 1]. It applies the ImageNet per-channel mean/std normalization,
    resizes bilinearly to 224x224, and returns VGG16's raw class scores (no softmax).
    """

    def __init__(self):
        super(VGGClass, self).__init__()
        # torchvision >= 0.13 deprecates `pretrained=True` in favour of the weights
        # enum (IMAGENET1K_V1 is the same checkpoint); fall back on older releases.
        weights_enum = getattr(torchvision.models, "VGG16_Weights", None)
        if weights_enum is not None:
            self.vggnet = torchvision.models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.vggnet = torchvision.models.vgg16(pretrained=True)

        self.transform = torch.nn.functional.interpolate
        # Buffers rather than Parameters: these are fixed constants. Buffers still move
        # with .to()/.cuda() and keep the same state_dict keys, but they are not
        # returned by .parameters() and do not require gradients.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        # This is an inference wrapper: start in eval mode so VGG16's dropout layers
        # do not randomize predictions.
        self.eval()

    def forward(self, input):
        input = (input - self.mean) / self.std
        input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
        return self.vggnet(input)
    
# Eval mode: VGG16's classifier contains dropout layers, which would otherwise
# randomize the predictions on every call.
vgg_class = VGGClass().eval()

# Classify a single batch from the loader. `batch` (not `data`) avoids shadowing
# the `torch.utils.data` alias imported at the top of the notebook.
batch = next(iter(torch_celeba))
print("Data size:", batch.size())

# cv2 decodes images as (height, width, channels), so a batch arrives as
# [BATCH_SIZE, HEIGHT, WIDTH, CHANNELS]. PyTorch conv layers expect
# [BATCH_SIZE, CHANNELS, HEIGHT, WIDTH]. `.float()` guarantees float32 input.
batch = batch.permute(0, 3, 1, 2).float()

# Inference only: no_grad skips building the autograd graph.
with torch.no_grad():
    predicted_class = vgg_class(batch)
print("Shape of Predicted classes", predicted_class.size())

imshow(torchvision.utils.make_grid(batch))

# Highest-scoring ImageNet class for each image, printed by name.
class_indices = torch.argmax(predicted_class, 1)
for i in class_indices:
    print(label_dict[i.item()])