In [None]:
# from dinov2.models.vision_transformer import vit_large
import os
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchmetrics
import pytorch_lightning as pl

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import pytorch_lightning as pl


class ImageDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``root_dir`` per class.

    Expected layout::

        root_dir/
            <class_a>/*.png
            <class_b>/*.png

    Class sub-directories are sorted by name and numbered 0, 1, 2, ... in that
    order. The label of a class therefore does not depend on the arbitrary
    order in which ``os.listdir`` returns entries, and two dataset roots with
    the same class folders (e.g. ``train`` and ``test``) get the same mapping.
    Plain files and hidden directories (names starting with ``.``, such as
    ``.ipynb_checkpoints``) directly under ``root_dir`` never consume a label.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each RGB ``PIL.Image`` before
            it is returned from ``__getitem__``.

    Attributes:
        classes: Sorted list of class directory names.
        class_to_idx: Mapping from class directory name to integer label.
        file_paths: Paths of all discovered images.
        labels: Integer label for each entry of ``file_paths``.
    """

    # Only files with this extension (compared case-insensitively) are loaded.
    IMAGE_EXTENSION = ".png"

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.file_paths = []
        self.labels = []

        # Keep real class directories only, in a deterministic order.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if not entry.startswith(".")
            and os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            label = self.class_to_idx[class_name]
            # Sorted so that sample order is reproducible across runs/machines.
            for file_name in sorted(os.listdir(class_dir)):
                if file_name.lower().endswith(self.IMAGE_EXTENSION):
                    self.file_paths.append(os.path.join(class_dir, file_name))
                    self.labels.append(label)

        print(f"Loaded {len(self.file_paths)} images from {root_dir}")
        if len(self.file_paths) == 0:
            print(f"No images found in directory: {root_dir}")

    def __len__(self):
        """Return the number of discovered images."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is opened, converted to RGB and, when a transform was
        given, passed through it. The label is the integer class index.
        """
        image_path = self.file_paths[idx]
        # convert() returns a fully loaded copy, so the underlying file handle
        # can be released immediately instead of waiting for garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]

        return image, label


class DataModule(pl.LightningDataModule):
    """Lightning data module for an image-folder dataset with train/test roots.

    ``<data_dir>/train`` is loaded with ``ImageDataset`` and split into a
    training and a validation subset; ``<data_dir>/test`` is used as the test
    set. The same ``transform`` is applied to every split.

    Args:
        data_dir: Directory containing the ``train`` and ``test`` folders.
        batch_size: Batch size used by all three dataloaders.
        transform: Callable applied to each image by ``ImageDataset``.
        val_size: Number of training images held out for validation.
            Defaults to 5, the value that was previously hard-coded.
        seed: Optional integer seed for the train/validation split. ``None``
            (the default) draws the split from the global torch RNG.
    """

    def __init__(self, data_dir, batch_size, transform, val_size=5, seed=None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transform
        self.val_size = val_size
        self.seed = seed

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` builds the train/validation
                subsets, ``"test"`` builds the test set, ``None`` builds all.

        Raises:
            ValueError: If a required folder contains no images, or if
                ``val_size`` does not leave at least one training and one
                validation sample.
        """
        # "validate" needs val_dataset as well, so it shares the fit branch.
        if stage in ("fit", "validate") or stage is None:
            train_dir = os.path.join(self.data_dir, "train")
            print(f"Setting up training dataset from {train_dir}")
            train_dataset = ImageDataset(train_dir, transform=self.transform)
            num_train = len(train_dataset)
            if num_train == 0:
                raise ValueError(f"No training data found in {train_dir}")
            # Fail early with a clear message instead of producing an empty or
            # negative-length split.
            if not 0 < self.val_size < num_train:
                raise ValueError(
                    f"val_size must be between 1 and {num_train - 1} for "
                    f"{num_train} training images, got {self.val_size}"
                )

            # Only pass a generator when a seed was requested so the default
            # behaviour (global RNG) is unchanged.
            split_kwargs = {}
            if self.seed is not None:
                split_kwargs["generator"] = torch.Generator().manual_seed(self.seed)
            self.train_dataset, self.val_dataset = random_split(
                train_dataset,
                [num_train - self.val_size, self.val_size],
                **split_kwargs,
            )

        if stage == "test" or stage is None:
            test_dir = os.path.join(self.data_dir, "test")
            print(f"Setting up test dataset from {test_dir}")
            self.test_dataset = ImageDataset(test_dir, transform=self.transform)
            num_test = len(self.test_dataset)
            if num_test == 0:
                raise ValueError(f"No test data found in {test_dir}")

    def train_dataloader(self):
        """Return the shuffled dataloader over the training subset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation subset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over the test set."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)


class LinearClassifierHead(nn.Module):
    """Single-output linear layer followed by a sigmoid.

    Projects an ``embed_dim``-sized feature vector to one value between
    0 and 1.
    """

    def __init__(self, embed_dim):
        """Create the ``embed_dim -> 1`` projection and its sigmoid activation."""
        super().__init__()
        self.head = nn.Linear(in_features=embed_dim, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(head(x))``; the last dimension of the result has size 1."""
        logits = self.head(x)
        probabilities = self.sigmoid(logits)
        return probabilities


import torch
import torchmetrics
import pytorch_lightning as pl
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
import os
from PIL import Image

# Fetch the DINOv2 "dinov2_vits14" model through torch.hub; it is shared as a
# module-level object and used as the feature extractor further down.
dinov2_vits14 = torch.hub.load(
    repo_or_dir="facebookresearch/dinov2",
    model="dinov2_vits14",
)


class CustomModel(pl.LightningModule):
    """Binary classifier: frozen feature backbone plus a trainable linear head.

    The backbone is only ever run under ``torch.no_grad()`` and the optimizer
    receives the head parameters alone, so training updates just the
    ``LinearClassifierHead``. The head emits sigmoid probabilities, which are
    scored with ``nn.BCELoss``.

    Args:
        embed_dim: Size of the feature vector produced by the backbone.
        learning_rate: Learning rate for the Adam optimizer on the head.
        backbone: Optional feature-extractor module. When omitted, the
            module-level ``dinov2_vits14`` model is used, which keeps the
            original two-argument constructor call working.
    """

    def __init__(self, embed_dim, learning_rate, backbone=None):
        super().__init__()
        self.dinov2_vits14 = dinov2_vits14 if backbone is None else backbone
        # The backbone is a fixed feature extractor: keep any train-time-only
        # layers (dropout, stochastic depth, ...) switched off.
        self.dinov2_vits14.eval()
        self.linear_classifier_head = LinearClassifierHead(embed_dim)
        self.criterion = nn.BCELoss()  # Binary Cross-Entropy Loss
        self.learning_rate = learning_rate
        self.train_acc = torchmetrics.Accuracy(task="binary")
        self.val_acc = torchmetrics.Accuracy(task="binary")
        self.test_acc = torchmetrics.Accuracy(task="binary")

    def train(self, mode=True):
        """Switch train/eval mode for the head while keeping the backbone in eval.

        ``nn.Module.train`` propagates ``mode`` to every submodule, which
        would otherwise put the frozen backbone back into training mode.
        """
        super().train(mode)
        self.dinov2_vits14.eval()
        return self

    def forward(self, x):
        """Return the head's sigmoid output for the image batch ``x``.

        Backbone features are computed without gradient tracking.
        """
        with torch.no_grad():
            features = self.dinov2_vits14(x)
        return self.linear_classifier_head(features)

    def configure_optimizers(self):
        """Return an Adam optimizer over the classifier head parameters only."""
        optimizer = torch.optim.Adam(
            self.linear_classifier_head.parameters(), lr=self.learning_rate
        )
        return optimizer

    def _shared_step(self, batch, accuracy, prefix):
        """Compute loss and accuracy for one batch and log both.

        Args:
            batch: ``(images, labels)`` pair from a dataloader.
            accuracy: Metric object of the current split.
            prefix: Log-key prefix (``"train"``, ``"val"`` or ``"test"``).

        Returns:
            The binary cross-entropy loss of the batch.
        """
        images, labels = batch
        # Drop only the trailing size-1 dimension of the head output. A bare
        # squeeze() would also remove the batch dimension of a single-sample
        # batch, leaving a 0-d tensor that no longer matches the label shape
        # required by BCELoss.
        outputs = self(images).squeeze(-1)
        loss = self.criterion(outputs, labels.float())
        acc = accuracy(outputs, labels.int())
        self.log(f"{prefix}_loss", loss)
        self.log(f"{prefix}_acc", acc)
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch; logs ``train_loss`` and ``train_acc``."""
        return self._shared_step(batch, self.train_acc, "train")

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; logs ``val_loss`` and ``val_acc``."""
        return self._shared_step(batch, self.val_acc, "val")

    def test_step(self, batch, batch_idx):
        """Run one test batch; logs ``test_loss`` and ``test_acc``."""
        return self._shared_step(batch, self.test_acc, "test")

In [None]:
# Example of initializing the DataModule and CustomModel for binary classification

# --- Configuration: everything a reader may want to tune --------------------
# Dataset root (must contain `train/` and `test/`). Set the IR_DATA_DIR
# environment variable to point elsewhere without editing the notebook.
DATA_DIR = os.environ.get("IR_DATA_DIR", "/home/nayoonkim/dino_test/a_notebooks/IR")
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 8
LEARNING_RATE = 1e-3
NUM_EPOCHS = 20
num_epochs = NUM_EPOCHS  # alias: other cells in this notebook use the lower-case name
SEED = 42
# ImageNet channel statistics, the normalisation used by the DINOv2
# repository's own evaluation transforms; the frozen backbone should see
# inputs on the same scale it was pretrained with.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Seed Python, NumPy and torch so the train/validation split and the head
# initialisation are reproducible across kernel restarts.
pl.seed_everything(SEED, workers=True)

transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

data_module = DataModule(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    transform=transform,
)
model = CustomModel(embed_dim=dinov2_vits14.embed_dim, learning_rate=LEARNING_RATE)

trainer = pl.Trainer(max_epochs=NUM_EPOCHS)

In [None]:
%%time
trainer.fit(model, data_module)

In [None]:
# Evaluate the trained model on the test split, then report completion.
trainer.test(model=model, datamodule=data_module)

print("Training and testing complete.")

In [None]:
# # Configuration
# batch_size = 4
# num_epochs = 50
# learning_rate = 0.001
# transform = transforms.Compose(
#     [
#         transforms.Resize((224, 224)),
#         transforms.ToTensor(),
#         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#     ]
# )

# data_module = DataModule(
#     data_dir="/home/nayoonkim/dino_test/a_notebooks/IR", batch_size=batch_size, transform=transform
# )
# model = CustomModel(embed_dim=dinov2_vits14.embed_dim, num_classes=3, learning_rate=learning_rate)
# trainer = pl.Trainer(max_epochs=num_epochs)

In [None]:
import matplotlib.pyplot as plt

# Visualize the accuracy
# compute() on each torchmetrics object yields a single scalar aggregated over
# the batches the metric has accumulated, not a per-epoch history.
train_acc = model.train_acc.compute().cpu().numpy()
val_acc = model.val_acc.compute().cpu().numpy()
test_acc = model.test_acc.compute().cpu().numpy()

# `num_epochs` is not defined by any executed cell above, so read the epoch
# count from the trainer that ran the fit.
epochs = range(1, trainer.max_epochs + 1)

print(
    f"train_acc={float(train_acc):.4f}  "
    f"val_acc={float(val_acc):.4f}  "
    f"test_acc={float(test_acc):.4f}"
)
# plt.figure(figsize=(10, 5))
# plt.plot(epochs, train_acc, label='Training Accuracy')
# plt.plot(epochs, val_acc, label='Validation Accuracy')
# plt.xlabel('Epochs')
# plt.ylabel('Accuracy')
# plt.title('Training and Validation Accuracy')
# plt.legend()
# # plt.show()

# # print("Training and testing complete.")

# epochs = range(1, num_epochs + 1)

# plt.figure(figsize=(10, 5))
# plt.plot(epochs, train_acc_history, label='Training Accuracy')
# plt.plot(epochs, val_acc_history, label='Validation Accuracy')
# plt.xlabel('Epochs')
# plt.ylabel('Accuracy')
# plt.title('Training and Validation Accuracy')
# plt.legend()
# plt.show()

# print("Training and testing complete.")

In [None]:
# Visualize the accuracy
#
# `trainer.callback_metrics` stores only the latest scalar for each logged
# key, so there is no per-epoch history to draw as a curve: plotting a scalar
# against a range of epochs raises a shape error, and `num_epochs` is not
# defined by any executed cell. The final value of each split is therefore
# shown as a bar.
# NOTE(review): callback metrics may be cleared when a new trainer run (e.g.
# `trainer.test`) starts; confirm for the installed Lightning version. When a
# key is missing, the value falls back to the metric object on the model.
final_accuracy = {}
for split_label, metric_key, metric in (
    ("Training Accuracy", "train_acc", model.train_acc),
    ("Validation Accuracy", "val_acc", model.val_acc),
    ("Test Accuracy", "test_acc", model.test_acc),
):
    logged_value = trainer.callback_metrics.get(metric_key)
    value = metric.compute() if logged_value is None else logged_value
    final_accuracy[split_label] = float(value.detach().cpu())

# Same variable names as before, now plain Python floats.
train_acc, val_acc, test_acc = final_accuracy.values()

fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(list(final_accuracy.keys()), list(final_accuracy.values()))
ax.set_ylim(0.0, 1.0)  # accuracy is a fraction
ax.set_xlabel("Split")
ax.set_ylabel("Accuracy")
ax.set_title("Final Training, Validation and Test Accuracy")
plt.show()

print("Training and testing complete.")