In [19]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms,models
from torch.utils.data import DataLoader


In [20]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [21]:
train_data = datasets.ImageFolder(
    root=r"alzheimer\Combined Dataset\train",
    transform=transform
)

test_data = datasets.ImageFolder(
    root=r"alzheimer\Combined Dataset\test",
    transform=transform
)


In [22]:
train_loader = DataLoader(
    train_data, batch_size=8, shuffle=True, num_workers=0, pin_memory=True
)
test_loader = DataLoader(
    test_data, batch_size=8, num_workers=0, pin_memory=True
)

In [23]:
class Alzheimer(nn.Module):
  def __init__(self):
    super().__init__()
    self.base = models.resnet18(pretrained=True)
    self.base.fc = nn.Linear(self.base.fc.in_features,4)
  def forward(self,x):
    return self.base(x)

In [24]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Alzheimer().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

In [25]:
!pip install tqdm




In [26]:
from tqdm import tqdm
def train():
  model.train()
  total_loss = 0
  for images,labels in tqdm(train_loader, desc="Training", leave=False):
    images,labels = images.to(device),labels.to(device)
    optimizer.zero_grad()
    output = model(images)
    loss = loss_fn(output,labels)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
  return total_loss/len(train_loader)


In [27]:
def evaluate():
  model.eval()
  correct,total = 0,0
  with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating", leave=False):
      images,labels = images.to(device),labels.to(device)
      preds = torch.argmax(model(images),dim=1)
      correct += (preds==labels).sum().item()
      total += labels.size(0)
  return correct/total

In [28]:
for epoch in tqdm(range(1, 30), desc="Epochs"):
  avg_loss = train()
  acc = evaluate()
  print(f"Epoch: {epoch}, Loss: {avg_loss:.4f}, Accuracy: {acc:.4f}")

Epochs:   3%|▎         | 1/29 [02:48<1:18:24, 168.01s/it]

Epoch: 1, Loss: 0.5707, Accuracy: 0.6693


Epochs:   7%|▋         | 2/29 [04:37<1:00:02, 133.43s/it]

Epoch: 2, Loss: 0.3090, Accuracy: 0.7224


Epochs:  10%|█         | 3/29 [06:30<53:46, 124.08s/it]  

Epoch: 3, Loss: 0.1926, Accuracy: 0.8554


Epochs:  14%|█▍        | 4/29 [08:18<49:10, 118.02s/it]

Epoch: 4, Loss: 0.1137, Accuracy: 0.9296


Epochs:  17%|█▋        | 5/29 [10:07<45:51, 114.63s/it]

Epoch: 5, Loss: 0.0822, Accuracy: 0.9038


Epochs:  21%|██        | 6/29 [12:03<44:07, 115.12s/it]

Epoch: 6, Loss: 0.0628, Accuracy: 0.9468


Epochs:  24%|██▍       | 7/29 [13:51<41:21, 112.80s/it]

Epoch: 7, Loss: 0.0585, Accuracy: 0.9476


Epochs:  28%|██▊       | 8/29 [15:39<38:57, 111.32s/it]

Epoch: 8, Loss: 0.0416, Accuracy: 0.9609


Epochs:  31%|███       | 9/29 [17:27<36:42, 110.14s/it]

Epoch: 9, Loss: 0.0477, Accuracy: 0.9679


Epochs:  34%|███▍      | 10/29 [19:16<34:48, 109.90s/it]

Epoch: 10, Loss: 0.0385, Accuracy: 0.9703


Epochs:  38%|███▊      | 11/29 [21:06<32:58, 109.94s/it]

Epoch: 11, Loss: 0.0277, Accuracy: 0.9797


Epochs:  41%|████▏     | 12/29 [22:55<31:01, 109.49s/it]

Epoch: 12, Loss: 0.0313, Accuracy: 0.9758


Epochs:  45%|████▍     | 13/29 [24:43<29:05, 109.09s/it]

Epoch: 13, Loss: 0.0275, Accuracy: 0.9797


Epochs:  48%|████▊     | 14/29 [26:31<27:13, 108.87s/it]

Epoch: 14, Loss: 0.0323, Accuracy: 0.9765


Epochs:  52%|█████▏    | 15/29 [28:20<25:25, 108.99s/it]

Epoch: 15, Loss: 0.0171, Accuracy: 0.9789


Epochs:  55%|█████▌    | 16/29 [30:14<23:55, 110.45s/it]

Epoch: 16, Loss: 0.0078, Accuracy: 0.9382


Epochs:  59%|█████▊    | 17/29 [32:21<23:05, 115.42s/it]

Epoch: 17, Loss: 0.0342, Accuracy: 0.9633


Epochs:  62%|██████▏   | 18/29 [34:12<20:54, 114.03s/it]

Epoch: 18, Loss: 0.0154, Accuracy: 0.9750


Epochs:  66%|██████▌   | 19/29 [36:24<19:54, 119.43s/it]

Epoch: 19, Loss: 0.0227, Accuracy: 0.9484


Epochs:  69%|██████▉   | 20/29 [38:12<17:23, 115.94s/it]

Epoch: 20, Loss: 0.0121, Accuracy: 0.9758


Epochs:  72%|███████▏  | 21/29 [40:00<15:09, 113.67s/it]

Epoch: 21, Loss: 0.0166, Accuracy: 0.9734


Epochs:  76%|███████▌  | 22/29 [41:48<13:04, 112.02s/it]

Epoch: 22, Loss: 0.0129, Accuracy: 0.9664


Epochs:  79%|███████▉  | 23/29 [43:38<11:08, 111.43s/it]

Epoch: 23, Loss: 0.0164, Accuracy: 0.9906


Epochs:  83%|████████▎ | 24/29 [45:35<09:25, 113.05s/it]

Epoch: 24, Loss: 0.0138, Accuracy: 0.9922


Epochs:  86%|████████▌ | 25/29 [47:24<07:26, 111.61s/it]

Epoch: 25, Loss: 0.0061, Accuracy: 0.9578


Epochs:  90%|████████▉ | 26/29 [49:13<05:32, 110.93s/it]

Epoch: 26, Loss: 0.0225, Accuracy: 0.9593


Epochs:  93%|█████████▎| 27/29 [51:51<04:09, 124.96s/it]

Epoch: 27, Loss: 0.0096, Accuracy: 0.9601


Epochs:  97%|█████████▋| 28/29 [54:18<02:11, 131.69s/it]

Epoch: 28, Loss: 0.0115, Accuracy: 0.9820


Epochs: 100%|██████████| 29/29 [56:40<00:00, 117.26s/it]

Epoch: 29, Loss: 0.0032, Accuracy: 0.9750





In [29]:
torch.save(model.state_dict(), "Alzehimer.pth")

In [30]:
print("Model saved as Alzheimer.pth")

Model saved as Alzheimer.pth
