In [14]:
import torch
import torchvision
from torch import nn
from torchvision import transforms
import os
import PIL
from PIL import Image
from torch.utils.data import Dataset , Subset , DataLoader
import torchvision.models as models
import torch.optim as optim



In [15]:
transform = transforms.Compose([
    transforms.Resize((256 , 256)) , 
    transforms.PILToTensor() , 
    transforms.ConvertImageDtype(torch.float32)
    ])

In [16]:
data_dir = "/home/aman/code/CV/throat_infection/data"

classes = sorted(os.listdir(data_dir))

class_to_label = {class_name: label for label ,class_name in enumerate(classes)}




In [17]:
class CustomImageDataset(Dataset):

    def __init__(self , data_dir , transforms=None):
        self.data_dir = data_dir
        self.transform = transforms
        self.image_path = []
        self.labels = []

        for class_name , label in class_to_label.items():
            class_dir = os.path.join(data_dir , class_name)

            for image_path in os.listdir(class_dir):
                self.image_path.append(os.path.join(class_dir , image_path))
                self.labels.append(label)

    def __len__(self):
        return len(self.image_path)
    
    def __getitem__(self ,idx):
        image_path = self.image_path[idx]
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        label = self.labels[idx]
        return image, label

                

In [18]:
dataset = CustomImageDataset(data_dir=data_dir , transforms=transform)

In [None]:
# testng layer


In [22]:
train_dataset = Subset(dataset , torch.arange(101))
valid_dataset = Subset(dataset , torch.arange(101 , len(dataset)))

device = torch.device("cuda")
batch_size = 5
train_dl = DataLoader(dataset , batch_size=batch_size , shuffle=True , pin_memory=True)
valid_dl = DataLoader(dataset , batch_size=batch_size , shuffle=True , pin_memory=True)


In [23]:
from torchvision.models import resnet50 , ResNet50_Weights


model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).to(device=device)
in_feature = model.fc.in_features

model.fc = nn.Linear(in_feature , 1)

In [24]:
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters() , lr=0.0001)

In [25]:
def train(model, n_epoch, train_dl, valid_dl, use_cuda=True):
    device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
    loss_hist_train = [0] * n_epoch
    loss_hist_valid = [0] * n_epoch
    acc_hist_valid = [0] * n_epoch
    acc_hist_train = [0] * n_epoch

    model.to(device)

    for epoch in range(n_epoch):
        model.train()
        for x, y in train_dl:
            x, y = x.to(device), y.to(device).float()  # Convert target labels to float
            pred = model(x).squeeze(1)  # Remove the extra dimension from the output
            loss = loss_fn(pred, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_hist_train[epoch] += loss.item() * y.size(0)
            is_correct = (torch.round(torch.sigmoid(pred)) == y).float()  # Apply sigmoid and round
            acc_hist_train[epoch] += is_correct.sum()
        loss_hist_train[epoch] /= len(train_dl.dataset)
        acc_hist_train[epoch] /= len(train_dl.dataset)

        model.eval()
        with torch.no_grad():
            for x, y in valid_dl:
                x, y = x.to(device), y.to(device).float()  # Convert target labels to float
                pred = model(x).squeeze(1)  # Remove the extra dimension from the output
                loss = loss_fn(pred, y)
                loss_hist_valid[epoch] += loss.item() * y.size(0)
                is_correct = (torch.round(torch.sigmoid(pred)) == y).float()  # Apply sigmoid and round
                acc_hist_valid[epoch] += is_correct.sum()
            loss_hist_valid[epoch] /= len(valid_dl.dataset)
            acc_hist_valid[epoch] /= len(valid_dl.dataset)
        print(f"epoch {epoch+1} accuracy {acc_hist_train[epoch]:.4f} val_accuracy: {acc_hist_valid[epoch]:.4f}")

    return loss_hist_train, loss_hist_valid, acc_hist_train, acc_hist_valid

In [27]:
torch.manual_seed(1)

num_epoch = 30

hist = train(model , num_epoch , train_dl , valid_dl)



epoch 1 accuracy 0.4586 val_accuracy: 0.5055
epoch 2 accuracy 0.4862 val_accuracy: 0.5193
epoch 3 accuracy 0.4917 val_accuracy: 0.5331
epoch 4 accuracy 0.4807 val_accuracy: 0.5359
epoch 5 accuracy 0.5635 val_accuracy: 0.5552
epoch 6 accuracy 0.5221 val_accuracy: 0.5552
epoch 7 accuracy 0.5939 val_accuracy: 0.5552
epoch 8 accuracy 0.5635 val_accuracy: 0.5773
epoch 9 accuracy 0.5746 val_accuracy: 0.5856
epoch 10 accuracy 0.5718 val_accuracy: 0.5967
epoch 11 accuracy 0.6050 val_accuracy: 0.6022
epoch 12 accuracy 0.6105 val_accuracy: 0.5939
epoch 13 accuracy 0.5773 val_accuracy: 0.5994
epoch 14 accuracy 0.5663 val_accuracy: 0.6050
epoch 15 accuracy 0.6022 val_accuracy: 0.6133
epoch 16 accuracy 0.6077 val_accuracy: 0.6050
epoch 17 accuracy 0.5829 val_accuracy: 0.6050
epoch 18 accuracy 0.5994 val_accuracy: 0.5967
epoch 19 accuracy 0.5801 val_accuracy: 0.6188
epoch 20 accuracy 0.5994 val_accuracy: 0.6105
epoch 21 accuracy 0.5856 val_accuracy: 0.6022
epoch 22 accuracy 0.5967 val_accuracy: 0.59