In [1]:
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from torchvision import transforms


def get_transform(input_size=224):
    return transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

class PlacesDataset(Dataset):
    def __init__(self, dataset_npy:str, transform = None, onlylabels=None):
        self.data = []
        self.classes = {
            'bathroom': 0,
            'bedroom': 1,
            'childs_room': 2,
            'classroom': 3,
            'dressing_room': 4,
            'living_room': 5,
            'studio': 6,
            'swimming_pool': 7
        }
        self.idx_to_class = {str(idx): value for idx, value in enumerate(self.classes.keys())}
        
        self.onlylabels = onlylabels
        self.transform = get_transform()
        reader = np.load(dataset_npy)
        for [img_path, label] in reader:
            self.data.append((img_path, self.classes[label])) #for some reason, label is str
        self.data = np.array(self.data) # data = [['fullpath', 'label'], ....]
        print("dataset_size", self.data.size)
        
        if self.onlylabels is not None:
            self.onlylabels = [int(i) for i in self.onlylabels]
            clip_indexes = np.where(self.data[:, 1] == str(self.onlylabels[0]))[0] # indexes
            for i in self.onlylabels[1:]:
                clip = np.where(self.data[:, 1] == str(i))[0]
                clip_indexes = np.append(clip_indexes, clip)
            clip_indexes.sort()
            self.data = self.data[clip_indexes]
        labels, counts = np.unique(self.data[:, 1], return_counts = True)
        self.labels = labels.astype(int) # labels are integers folowwing self.classes
        

        # Calculate class weights for WeightedRandomSampler
        self.class_counts = dict(zip(labels, counts))
        self.class_weights = {label: max(self.class_counts.values()) / count
                              for label, count in self.class_counts.items()}
        self.sampler_weights = [self.class_weights[cls] for cls in self.data[:, 1]]
        
        self.class_weights_list = [self.class_weights[k]
                                   for k in sorted(self.class_weights)]

        print('Found {} images from {} classes.'.format(len(self.data),
                                                        len(self.labels)))
        for idx in self.class_counts.keys():
            print("    Class '{}' ({}): {} images.".format(
                  self.idx_to_class[idx], idx, self.class_counts[idx]))
            

    def __getitem__(self, index:int):
        img_path, label = self.data[index]
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(int(label))
        return image, label

    def __len__(self):
        return len(self.data)

In [2]:
# from dataloader_places import PlacesDataset
from torch.utils.data import DataLoader

places_dataset = PlacesDataset('Places8_paths_and_labels_complete_train.npy')

dataset_size 729612
Found 364806 images from 8 classes.
    Class 'bathroom' (0): 51655 images.
    Class 'bedroom' (1): 100012 images.
    Class 'childs_room' (2): 41849 images.
    Class 'classroom' (3): 33763 images.
    Class 'dressing_room' (4): 21889 images.
    Class 'living_room' (5): 89458 images.
    Class 'studio' (6): 12633 images.
    Class 'swimming_pool' (7): 13547 images.


In [3]:
import torchvision
resnet18 = torchvision.models.resnet18(pretrained=True)
resnet18.fc = torch.nn.Identity()
resnet18

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
dataloader = DataLoader(places_dataset, batch_size=16)
for batch in dataloader:
    print(resnet18(batch[0]).shape)
    break

torch.Size([16, 512])
