In [153]:
from transformers import AutoTokenizer
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import Compose, ColorJitter, ToTensor, Resize
from torchvision.models import ResNet50_Weights
from torch.utils.data import DataLoader
from datasets import load_dataset,ClassLabel,Value,concatenate_datasets,Features,Array2D

import numpy as np

In [7]:

dataset_diffusion_train = load_dataset("poloclub/diffusiondb","2m_first_10k")
dataset_diffusion_test = load_dataset("poloclub/diffusiondb","2m_random_5k")
dataset_imagenette = load_dataset("frgfm/imagenette","160px")

Found cached dataset diffusiondb (/home/hakim/.cache/huggingface/datasets/poloclub___diffusiondb/2m_first_10k/0.9.1/547894e3a57aa647ead68c9faf148324098f47f2bc1ab6705d670721de9d89d1)
100%|██████████| 1/1 [00:00<00:00, 220.03it/s]
Found cached dataset diffusiondb (/home/hakim/.cache/huggingface/datasets/poloclub___diffusiondb/2m_random_5k/0.9.1/547894e3a57aa647ead68c9faf148324098f47f2bc1ab6705d670721de9d89d1)
100%|██████████| 1/1 [00:00<00:00, 673.68it/s]
Found cached dataset imagenette (/home/hakim/.cache/huggingface/datasets/frgfm___imagenette/160px/1.0.0/38929285b8abcae5c1305418e9d8fea5dd6b189bbbd22caba5f5537c7fa0f01f)
100%|██████████| 2/2 [00:00<00:00, 714.41it/s]


In [156]:
train_db = dataset_diffusion_train["train"].select(range(9469))
val_db = dataset_diffusion_test["train"].select(range(1962))
test_db = dataset_diffusion_test["train"].select(range(1962,3925))

In [157]:
train_I = dataset_imagenette["train"]
val_I = dataset_imagenette["validation"].select(range(1962))
test_I = dataset_imagenette["validation"].select(range(1962,3925)) 

In [158]:
label_column = len(train_db) * [0]
train_db = train_db.add_column("label",label_column)
label_column = len(val_db) * [0]
val_db = val_db.add_column("label",label_column)
label_column = len(test_db) * [0]
test_db = test_db.add_column("label",label_column)

In [159]:
del_col = ['prompt','seed', 'step','cfg','sampler','width','height','user_name','timestamp','image_nsfw', 'prompt_nsfw']
train_db = train_db.remove_columns(del_col)
val_db = val_db.remove_columns(del_col)
test_db = test_db.remove_columns(del_col)

In [161]:
train_I = train_I.remove_columns("label")
val_I = val_I.remove_columns("label")
test_I = test_I.remove_columns("label")

In [162]:
label_column = len(train_I) * [1]
train_I = train_I.add_column("label",label_column)
label_column = len(val_I) * [1]
val_I = val_I.add_column("label",label_column)
label_column = len(test_I) * [1]
test_I = test_I.add_column("label",label_column)

In [163]:
train = concatenate_datasets([train_I,train_db])
val = concatenate_datasets([val_I,val_db])
test = concatenate_datasets([test_I,test_db])

In [164]:
jitter = Compose(
    [ ToTensor(),Resize((160,160))]
)

In [165]:
def transforms(examples):

    examples["pixel_values"] = [jitter(image.convert("RGB")) for image in examples["image"]]

    return examples

In [166]:
train_dataset = train.with_transform(transforms)
val_dataset = val.with_transform(transforms)
test_dataset = train.with_transform(transforms)

In [167]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set hyperparameters
num_epochs = 10
batch_size = 64
learning_rate = 0.001

In [168]:
def collate_fn(examples):

    images = []

    labels = []

    for example in examples:

        images.append((example["pixel_values"]))

        labels.append(example["label"])

        

    pixel_values = torch.stack(images)

    labels = torch.tensor(labels)

    return {"pixel_values": pixel_values, "label": labels}

train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size)
val_dataloader =  DataLoader(val_dataset, collate_fn=collate_fn, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size)

In [169]:
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

In [170]:
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [171]:
tr = model.train()

In [172]:
train_accu = []
train_losses = []

def train(epoch,data):
    print('\nEpoch : %d'%epoch)
    correct = 0
    running_loss=0
    total=0
    for element in data:
        # Move input and label tensors to the device
        inputs = element["pixel_values"].to(device)
        labels = element["label"].to(device)

        # Zero out the optimizer
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_loss=running_loss/len(data)
    accu=100.*correct/total

    train_accu.append(accu)
    train_losses.append(train_loss)
    print('Train Loss: %.3f | Accuracy: %.3f'%(train_loss,accu))


In [173]:
eval_losses=[]
eval_accu=[]

def val(data):
    correct = 0
    running_loss=0
    total=0
    with torch.no_grad():
      for element in data:
          # Move input and label tensors to the device
          inputs = element["pixel_values"].to(device)
          labels = element["label"].to(device)

        
          outputs = model(inputs)
          loss = criterion(outputs, labels)
          running_loss += loss.item()

          _, predicted = outputs.max(1)
          total += labels.size(0)
          correct += predicted.eq(labels).sum().item()

    val_loss=running_loss/len(data)
    accu=100.*correct/total

    eval_accu.append(accu)
    eval_losses.append(val_loss)
    print('Val Loss: %.3f | Accuracy: %.3f'%(val_loss,accu))


In [174]:
epochs=10
for epoch in range(1,epochs+1):
  train(epoch,train_dataloader)
  val(val_dataloader)


Epoch : 1




Train Loss: 0.201 | Accuracy: 97.661
Val Loss: 6.928 | Accuracy: 50.000

Epoch : 2
Train Loss: 0.240 | Accuracy: 96.420
Val Loss: 6.439 | Accuracy: 50.000

Epoch : 3
