In [1]:
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch

import numpy as np

In [2]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [None]:
def image_cleaner(root_dir, save_dir):
    from PIL import Image
    import os

    image_paths = os.listdir(root_dir)
    print(len(os.listdir(save_dir)))

    for i, image_path in enumerate(image_paths):
        if not os.path.exists(f"{save_dir}/{image_path}"):
            try:
                img = Image.open(f"{root_dir}/{image_path}")
                img = img.convert("RGB")
                img.save(f"{save_dir}/{image_path}")

                print(f"{i} out of {len(image_paths)}")
            except Exception as e:
                print(f"{image_path} FAILED {e}")

image_cleaner(
    root_dir="drive/MyDrive/kaggle/machine_unlearning/data",
    save_dir="drive/MyDrive/kaggle/machine_unlearning/data_clean"
)

In [3]:
class Model(nn.Module):
    def __init__(self, nr_classes):
        from torchvision.models import resnet50, ResNet50_Weights

        super(Model, self).__init__()

        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.nr_features = self.model.fc.in_features
        self.model.fc = nn.Linear(self.nr_features, nr_classes)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.model(x)
        logits = self.softmax(x)

        return logits

In [14]:
class UTKFace(Dataset):
    def __init__(self, root_dir, device):
        from torchvision.models import ResNet50_Weights
        import os

        self.root_dir = root_dir
        self.device = device

        self.image_paths = random.sample(os.listdir(self.root_dir), 100)

        self.image_labels = torch.Tensor(
            [int(x.split("_")[0])//10 for x in self.image_paths]
        ).to(torch.int64)

        self.num_classes = int(self.image_labels.max()) + 1

        self.image_ohe = nn.functional.one_hot(
            self.image_labels, num_classes=self.num_classes
        ).to(torch.float32).to(self.device)

        self.image_transform = ResNet50_Weights.DEFAULT.transforms()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        from torchvision.io import read_image

        img = read_image(f"{self.root_dir}/{self.image_paths[idx]}")
        img_tensor = self.image_transform(img).to(self.device)
        label = self.image_ohe[idx]

        return img_tensor, label

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
dataset = UTKFace(
    root_dir="drive/MyDrive/kaggle/machine_unlearning/data_clean",
    device=device
)

train_size = int(len(dataset)*0.7)
validation_size = int(len(dataset)*0.1)
test_size = len(dataset) - train_size - validation_size

train_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, validation_size, test_size]
)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

In [16]:
model = Model(
    nr_classes=dataset.num_classes,
).to(device)

In [17]:
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

In [19]:
n_epochs = 100

for epoch in range(n_epochs):
    for mode in ["train"]:#, "validation"]:
        if mode == "train":
            model.train()
            dataloader = train_dataloader
        elif mode == "validation":
            model.eval()
            dataloader = validation_dataloader

        print(f"MODE {mode}", flush=True)

        total_loss = 0.
        total_correct = 0.
        for batch, data in enumerate(dataloader):
            optimizer.zero_grad()

            inputs, labels = data
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            correct = (outputs.argmax(dim=1) == labels.argmax(dim=1)).sum().item()

            total_loss += loss.item()
            total_correct += correct

            if mode == "train":
                loss.backward()
                optimizer.step()

            if batch % 1 == 0:
                print(f"BATCH {batch} LOSS {loss.item()} ACCURACY: {correct/len(inputs)}")

        print(f"MODE {mode} EPOCH {epoch} AVG LOSS {total_loss/len(dataloader)} AVG ACCURACY: {total_correct/train_size}")
        print("="*100)



MODE train
BATCH 0 LOSS 2.3053383827209473 ACCURACY: 0.09375
BATCH 1 LOSS 2.271394729614258 ACCURACY: 0.5
MODE train EPOCH 0 AVG LOSS 2.2883665561676025 AVG ACCURACY: 0.12857142857142856
MODE train
BATCH 0 LOSS 2.1061134338378906 ACCURACY: 0.84375
BATCH 1 LOSS 2.1005048751831055 ACCURACY: 0.16666666666666666
MODE train EPOCH 1 AVG LOSS 2.103309154510498 AVG ACCURACY: 0.7857142857142857
MODE train
BATCH 0 LOSS 1.8721628189086914 ACCURACY: 0.921875
BATCH 1 LOSS 1.8956153392791748 ACCURACY: 0.6666666666666666
MODE train EPOCH 2 AVG LOSS 1.883889079093933 AVG ACCURACY: 0.9
MODE train
BATCH 0 LOSS 1.702397108078003 ACCURACY: 0.953125
BATCH 1 LOSS 1.8769481182098389 ACCURACY: 0.6666666666666666
MODE train EPOCH 3 AVG LOSS 1.789672613143921 AVG ACCURACY: 0.9285714285714286
MODE train
BATCH 0 LOSS 1.58231782913208 ACCURACY: 1.0
BATCH 1 LOSS 1.7283586263656616 ACCURACY: 1.0
MODE train EPOCH 4 AVG LOSS 1.6553382277488708 AVG ACCURACY: 1.0
MODE train
BATCH 0 LOSS 1.5363870859146118 ACCURACY: 1.0


KeyboardInterrupt: 