In [11]:
import pandas as pd
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

data_path = "/run/media/kevin/Volume/OpenImages/"

In [12]:
class ImageDataset(Dataset):
    """Test Class"""
    
    def __init__(self, label_file, root_dir, transform=None):
        self.labels = pd.read_csv(label_file)
        self.root_dir = root_dir
        self.transform = transform
        
        # Make some adjustments
        self.labels.index = self.labels['ImageID']
        self.labels = self.labels[self.labels['Confidence'] == 1]
        self.label_names = np.sort(list(set(self.labels['LabelName'])))
        self.images = list(set(self.labels['ImageID']))
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        """ Get an image"""
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        # Load image
        img_id = self.images[idx]
        img_path = os.path.join(self.root_dir, self.images[idx]+".jpg")
        image = Image.open(img_path)
        
        # Create label vector
        img_objects = np.array(self.labels.loc[img_id,]['LabelName'])
        image_label = np.array([1 if x in img_objects else 0 for x in self.label_names])
        
        # Apply transform
        if self.transform is not None:
            image = self.transform(image)
        
        image = np.asarray(image)
        sample = {'image': image, 'label': image_label}
        
        return sample
    
        

In [13]:
transform = transforms.Compose(
    [transforms.Scale((299, 299)),
     transforms.Grayscale(3),
     transforms.ToTensor()])

root_dir = os.path.join(data_path, "pics")
csv_path = os.path.join(data_path, "subset_train_annots.csv")
dataset = ImageDataset(label_file = csv_path, root_dir = root_dir, transform=transform)

## Model

In [14]:
n_classes = len(dataset.label_names)
print(f"{n_classes} classes")

model = models.resnet18(pretrained=False)
model.fc = nn.Linear(in_features=512, out_features=n_classes, bias=True)

# specify device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

data_loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=6)
len(data_loader)

601 classes


99520

In [15]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(10):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, batch in enumerate(data_loader):

        # get the inputs; data is a list of [inputs, labels]
        imgs = batch['image'].to(device)
        labels = batch['label'].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(imgs)
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 0:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.6f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f154650aef0>
Traceback (most recent call last):
  File "/home/kevin/.conda/envs/image/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 926, in __del__
    self._shutdown_workers()
  File "/home/kevin/.conda/envs/image/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 906, in _shutdown_workers
    w.join()
  File "/home/kevin/.conda/envs/image/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f154650aef0>
Traceback (most recent call last):
  File "/home/kevin/.conda/envs/image/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 926, in __del__
    self._shutdown_workers()
  File "/home/kevin/.conda/envs/image/lib/python3.7/site-packages/torch/utils/data/datal

[1,     1] loss: 0.007233
[1,   101] loss: 0.039641
[1,   201] loss: 0.024538
[1,   301] loss: 0.024384
[1,   401] loss: 0.024680
[1,   501] loss: 0.024036
[1,   601] loss: 0.023200
[1,   701] loss: 0.025399
[1,   801] loss: 0.023813
[1,   901] loss: 0.025102
[1,  1001] loss: 0.023788
[1,  1101] loss: 0.023258
[1,  1201] loss: 0.022903
[1,  1301] loss: 0.023659
[1,  1401] loss: 0.022684
[1,  1501] loss: 0.023321
[1,  1601] loss: 0.023264
[1,  1701] loss: 0.023162
[1,  1801] loss: 0.023386
[1,  1901] loss: 0.021925
[1,  2001] loss: 0.023304
[1,  2101] loss: 0.022473
[1,  2201] loss: 0.024001
[1,  2301] loss: 0.023161
[1,  2401] loss: 0.022656
[1,  2501] loss: 0.022059
[1,  2601] loss: 0.023003
[1,  2701] loss: 0.022626
[1,  2801] loss: 0.022765
[1,  2901] loss: 0.021956
[1,  3001] loss: 0.022633
[1,  3101] loss: 0.022269
[1,  3201] loss: 0.022318
[1,  3301] loss: 0.022614
[1,  3401] loss: 0.021961
[1,  3501] loss: 0.022877
[1,  3601] loss: 0.021207
[1,  3701] loss: 0.021659
[1,  3801] l

[1, 31601] loss: 0.017462
[1, 31701] loss: 0.017664
[1, 31801] loss: 0.017949
[1, 31901] loss: 0.017860
[1, 32001] loss: 0.017661
[1, 32101] loss: 0.017582
[1, 32201] loss: 0.017315
[1, 32301] loss: 0.017006
[1, 32401] loss: 0.018232
[1, 32501] loss: 0.017360
[1, 32601] loss: 0.017828
[1, 32701] loss: 0.017460
[1, 32801] loss: 0.017199
[1, 32901] loss: 0.017610
[1, 33001] loss: 0.017882
[1, 33101] loss: 0.017142
[1, 33201] loss: 0.017671
[1, 33301] loss: 0.017590
[1, 33401] loss: 0.017638
[1, 33501] loss: 0.017748
[1, 33601] loss: 0.017628
[1, 33701] loss: 0.017184
[1, 33801] loss: 0.017490
[1, 33901] loss: 0.018016
[1, 34001] loss: 0.016658
[1, 34101] loss: 0.018314
[1, 34201] loss: 0.017049
[1, 34301] loss: 0.018619
[1, 34401] loss: 0.017744
[1, 34501] loss: 0.018338
[1, 34601] loss: 0.017063
[1, 34701] loss: 0.017006
[1, 34801] loss: 0.017107
[1, 34901] loss: 0.018366
[1, 35001] loss: 0.017314
[1, 35101] loss: 0.017268
[1, 35201] loss: 0.017219
[1, 35301] loss: 0.016972
[1, 35401] l

[1, 63201] loss: 0.015888
[1, 63301] loss: 0.015857
[1, 63401] loss: 0.015095
[1, 63501] loss: 0.016056
[1, 63601] loss: 0.015502
[1, 63701] loss: 0.014887
[1, 63801] loss: 0.015555
[1, 63901] loss: 0.015508
[1, 64001] loss: 0.015727
[1, 64101] loss: 0.016066
[1, 64201] loss: 0.015824
[1, 64301] loss: 0.016197
[1, 64401] loss: 0.015125
[1, 64501] loss: 0.015748
[1, 64601] loss: 0.015861
[1, 64701] loss: 0.015600
[1, 64801] loss: 0.016596
[1, 64901] loss: 0.016126
[1, 65001] loss: 0.016247
[1, 65101] loss: 0.016148
[1, 65201] loss: 0.015824
[1, 65301] loss: 0.014826
[1, 65401] loss: 0.016080
[1, 65501] loss: 0.015597
[1, 65601] loss: 0.015510
[1, 65701] loss: 0.016382
[1, 65801] loss: 0.015501
[1, 65901] loss: 0.015607
[1, 66001] loss: 0.015559
[1, 66101] loss: 0.015693
[1, 66201] loss: 0.015944
[1, 66301] loss: 0.015593
[1, 66401] loss: 0.016277
[1, 66501] loss: 0.015907
[1, 66601] loss: 0.015605
[1, 66701] loss: 0.015912
[1, 66801] loss: 0.015199
[1, 66901] loss: 0.015655
[1, 67001] l

[1, 94801] loss: 0.014118
[1, 94901] loss: 0.015313
[1, 95001] loss: 0.014441
[1, 95101] loss: 0.014582
[1, 95201] loss: 0.015196
[1, 95301] loss: 0.014525
[1, 95401] loss: 0.014945
[1, 95501] loss: 0.014326
[1, 95601] loss: 0.014802
[1, 95701] loss: 0.014122
[1, 95801] loss: 0.014841
[1, 95901] loss: 0.014537
[1, 96001] loss: 0.015220
[1, 96101] loss: 0.015232
[1, 96201] loss: 0.015362
[1, 96301] loss: 0.014325
[1, 96401] loss: 0.014509
[1, 96501] loss: 0.015251
[1, 96601] loss: 0.014274
[1, 96701] loss: 0.014909
[1, 96801] loss: 0.015352
[1, 96901] loss: 0.014561
[1, 97001] loss: 0.014696
[1, 97101] loss: 0.015037
[1, 97201] loss: 0.014838
[1, 97301] loss: 0.014062
[1, 97401] loss: 0.015021
[1, 97501] loss: 0.014656
[1, 97601] loss: 0.014815
[1, 97701] loss: 0.014027
[1, 97801] loss: 0.014483
[1, 97901] loss: 0.014502
[1, 98001] loss: 0.014515
[1, 98101] loss: 0.014498
[1, 98201] loss: 0.014674
[1, 98301] loss: 0.014347
[1, 98401] loss: 0.015216
[1, 98501] loss: 0.015095
[1, 98601] l

KeyboardInterrupt: 