In [1]:
import os
import glob
import pandas as pd
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split

device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'
device

device(type='cuda', index=0)

In [2]:
from dataset import CelebHQAttrDataset, AnnotatedFFHQDataset

In [3]:
data = CelebHQAttrDataset(image_size = 256)

  self.df = pd.read_csv(f, delim_whitespace=True)


In [4]:
data[56]

{'img': tensor([[[-0.7804, -0.7725, -0.7569,  ...,  0.2392,  0.2392,  0.2392],
          [-0.8039, -0.7882, -0.7882,  ...,  0.2392,  0.2392,  0.2392],
          [-0.8275, -0.8118, -0.8039,  ...,  0.2471,  0.2392,  0.2392],
          ...,
          [-0.1216, -0.1137, -0.1216,  ...,  0.0039,  0.0039,  0.0039],
          [-0.1294, -0.1294, -0.1216,  ...,  0.0039,  0.0039,  0.0039],
          [-0.1451, -0.1451, -0.1373,  ...,  0.0039,  0.0039,  0.0039]],
 
         [[-0.8588, -0.8510, -0.8588,  ..., -0.6392, -0.6392, -0.6392],
          [-0.8902, -0.8745, -0.8745,  ..., -0.6392, -0.6392, -0.6392],
          [-0.8980, -0.8824, -0.8745,  ..., -0.6314, -0.6392, -0.6392],
          ...,
          [-0.3804, -0.3725, -0.3804,  ...,  0.0039,  0.0039,  0.0039],
          [-0.3725, -0.3725, -0.3804,  ...,  0.0039,  0.0039,  0.0039],
          [-0.3882, -0.3882, -0.3961,  ...,  0.0039,  0.0039,  0.0039]],
 
         [[-0.9294, -0.9216, -0.9059,  ..., -0.7255, -0.7255, -0.7255],
          [-0.9373, -

In [5]:
data[56]['img'].shape

torch.Size([3, 256, 256])

In [6]:
labels = data[56]['labels']
gt = torch.where(labels > 0,
                             torch.ones_like(labels).float(),
                             torch.zeros_like(labels).float())
gt

tensor([0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0.,
        1., 0., 0., 1.])

In [7]:
data2 = AnnotatedFFHQDataset('/projects/deepdevpath/Saranga/Explaining-In-Style-Reproducibility-Study/data')

In [8]:
data2[0]

{'img': tensor([[[-0.9686, -1.0000, -0.9843,  ..., -0.9922, -1.0000, -1.0000],
          [-0.9765, -1.0000, -0.9922,  ..., -1.0000, -1.0000, -1.0000],
          [-0.9686, -0.9922, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          ...,
          [-0.8353, -0.8196, -0.8118,  ...,  0.0196,  0.0275,  0.0353],
          [-0.8667, -0.8588, -0.8510,  ...,  0.0667,  0.0824,  0.0745],
          [-0.8745, -0.8667, -0.8588,  ...,  0.0275,  0.0588,  0.0588]],
 
         [[ 0.0275, -0.0196,  0.0118,  ..., -0.0667, -0.0902, -0.0824],
          [ 0.0196, -0.0118,  0.0039,  ..., -0.0980, -0.1059, -0.0902],
          [ 0.0275,  0.0039, -0.0039,  ..., -0.1059, -0.1216, -0.1137],
          ...,
          [ 0.2000,  0.2157,  0.2235,  ...,  0.2157,  0.2157,  0.2235],
          [ 0.1608,  0.1686,  0.1765,  ...,  0.2549,  0.2471,  0.2392],
          [ 0.1529,  0.1608,  0.1686,  ...,  0.2157,  0.2235,  0.2235]],
 
         [[ 0.1451,  0.0980,  0.1451,  ...,  0.2157,  0.1843,  0.1922],
          [ 0.1373,  

In [9]:
data2[6392]

{'img': tensor([[[ 0.2941,  0.2941,  0.3020,  ...,  0.5137,  0.5373,  0.4980],
          [ 0.3333,  0.3176,  0.3020,  ...,  0.5137,  0.5294,  0.5137],
          [ 0.3333,  0.3098,  0.3020,  ...,  0.5137,  0.5059,  0.5216],
          ...,
          [ 0.0588, -0.0980, -0.3725,  ...,  0.6235,  0.6000,  0.5922],
          [-0.3176, -0.4039, -0.4667,  ...,  0.6157,  0.6000,  0.6000],
          [-0.4588, -0.5137, -0.4510,  ...,  0.5843,  0.5608,  0.5843]],
 
         [[ 0.0275,  0.0275,  0.0353,  ...,  0.6235,  0.6471,  0.6078],
          [ 0.0667,  0.0510,  0.0353,  ...,  0.6235,  0.6392,  0.6235],
          [ 0.0667,  0.0431,  0.0353,  ...,  0.6392,  0.6314,  0.6471],
          ...,
          [-0.0902, -0.2078, -0.4275,  ...,  0.5843,  0.5608,  0.5529],
          [-0.3490, -0.4196, -0.4824,  ...,  0.5765,  0.5608,  0.5608],
          [-0.4196, -0.4745, -0.4431,  ...,  0.5451,  0.5216,  0.5451]],
 
         [[-0.1843, -0.1843, -0.1765,  ...,  0.6235,  0.6471,  0.6078],
          [-0.1451, -

In [10]:
data2[6392]['img'].shape

torch.Size([3, 256, 256])

In [11]:
data2[6392]['labels']

tensor(1)

In [12]:
classifier = nn.Linear(512, 1)

In [20]:
x = torch.randn(64, 512)
pred = classifier(x)

In [21]:
pred.shape

torch.Size([64, 1])

In [26]:
dl = DataLoader(data2, batch_size=64)

batch = next(iter(dl))


In [33]:
pred

tensor([[-4.6785e-01],
        [ 6.3465e-01],
        [-1.1190e+00],
        [-9.8396e-01],
        [-3.1514e-01],
        [-6.1823e-02],
        [-6.8299e-01],
        [ 3.0822e-01],
        [ 6.6234e-01],
        [ 2.8170e-01],
        [-1.4998e+00],
        [ 1.1529e-01],
        [ 1.0124e+00],
        [ 3.2166e-01],
        [-1.3568e-01],
        [-8.8271e-01],
        [ 4.3810e-01],
        [-4.0126e-02],
        [-2.0673e-01],
        [ 2.2897e-01],
        [-1.0990e+00],
        [-4.4274e-01],
        [ 1.5857e-01],
        [ 4.9157e-01],
        [ 1.9862e-01],
        [-6.3672e-02],
        [-1.7471e-01],
        [-5.9854e-01],
        [ 1.3183e+00],
        [ 6.9210e-01],
        [ 3.1166e-01],
        [ 1.4757e-03],
        [ 1.9216e+00],
        [ 1.1871e-01],
        [ 3.9865e-01],
        [-2.1237e-01],
        [-3.4822e-01],
        [ 3.5430e-01],
        [-8.3735e-01],
        [ 1.9773e-01],
        [ 2.0905e-01],
        [ 6.6396e-02],
        [ 2.7827e-01],
        [-8

In [30]:
# gt = torch.randint(low=0, high=2, size=(64,))
gt = batch['labels'].unsqueeze(1)
print(gt)

tensor([[0],
        [1],
        [1],
        [1],
        [1],
        [0],
        [1],
        [1],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [0],
        [1],
        [0],
        [1],
        [0],
        [0],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [0],
        [0],
        [0],
        [1],
        [0],
        [0],
        [0],
        [1],
        [0],
        [1],
        [0],
        [0],
        [0],
        [1],
        [0],
        [1],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [0],
        [0],
        [0],
        [1],
        [0],
        [0],
        [0],
        [0]])


In [31]:
loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, gt)

RuntimeError: result type Float can't be cast to the desired output type Long