## All Imports Here

In [2]:
from PIL import Image
import os
import math
import random
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision.transforms import transforms
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
import torch.optim as optim
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt
import wandb



## Log in to Wandb

In [None]:

# Log in to Weights & Biases (wandb) for this session.
wandb.login()

## Custom Classes for Loading Data and Model Architecture

In [None]:
class CustomPalmOilDataset(Dataset):
    """Image dataset built from per-class sub-folders of ``root_dir``.

    Files found in ``<root_dir>/Pure`` get label 0 and files found in
    ``<root_dir>/Adulterated`` get label 1. Every other entry in ``root_dir``
    is ignored. The collected (path, label) pairs are shuffled once, at
    construction time, with the global ``random`` module.

    Args:
        root_dir: Directory that contains the class sub-folders.
        transform: Optional callable applied to the loaded RGB image.
        target_transform: Optional callable applied to the integer label.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist (from ``os.listdir``).
    """

    def __init__(self, root_dir='andrews_directory', transform=None, target_transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform

        # 'Pure' corresponds to label 0 and 'Adulterated' corresponds to label 1.
        label_mapping = {'Pure': 0, 'Adulterated': 1}

        samples = []
        # Sorted listings make the pre-shuffle order independent of the file
        # system, so seeding ``random`` yields a reproducible sample order.
        for label_folder in sorted(os.listdir(root_dir)):
            label = label_mapping.get(label_folder)
            label_folder_path = os.path.join(root_dir, label_folder)
            # Skip unknown names and stray files that share a class name.
            if label is None or not os.path.isdir(label_folder_path):
                continue
            for image_name in sorted(os.listdir(label_folder_path)):
                image_path = os.path.join(label_folder_path, image_name)
                # Nested directories cannot be opened as images; skip them.
                if os.path.isfile(image_path):
                    samples.append((image_path, label))

        random.shuffle(samples)

        # Built element-wise (instead of ``zip(*samples)``) so that an empty
        # dataset yields two empty tuples rather than raising ValueError.
        self.image_paths = tuple(path for path, _ in samples)
        self.labels = tuple(label for _, label in samples)

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is loaded from disk, converted to RGB, and passed through
        ``transform`` when one was given; the label is passed through
        ``target_transform`` when one was given.
        """
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file as soon as the RGB
        # copy has been produced, instead of waiting for garbage collection.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label




class PalmOilClassifier(nn.Module):
    """Small LeNet-style CNN: two conv/pool stages followed by three linear layers.

    With the default arguments the network matches the original fixed
    architecture: 3-channel 224x224 inputs and 2 output logits.

    Args:
        num_classes: Number of output logits produced by the last layer.
        input_size: Spatial size of the input images, either a single int for
            square images or a ``(height, width)`` pair. It determines the
            number of input features of the first linear layer.

    Raises:
        ValueError: If ``input_size`` is too small to survive both conv/pool
            stages.
    """

    def __init__(self, num_classes=2, input_size=224):
        super().__init__()
        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size

        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        feat_h = self._size_after_conv_stages(height)
        feat_w = self._size_after_conv_stages(width)
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small for two 5x5 conv + 2x2 pool stages"
            )

        # For 224x224 inputs this is 16*53*53, the value previously hard-coded.
        self.fc1 = nn.Linear(16 * feat_h * feat_w, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    @staticmethod
    def _size_after_conv_stages(size):
        """Return one spatial dimension after both conv(5, no padding) + pool(2) stages."""
        for _ in range(2):
            # A 5x5 unpadded conv removes 4 pixels; 2x2 max-pooling halves (floor).
            size = (size - 4) // 2
        return size

    def forward(self, x):
        """Map a batch of images to a ``(batch, num_classes)`` tensor of logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # Flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # Raw logits; no softmax is applied here
        return x