In [13]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms,models
from torch.utils.data import DataLoader


In [14]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])


In [15]:
train_data = datasets.ImageFolder(
    root=r"chestctscan\Data\train",
    transform=transform
)

test_data = datasets.ImageFolder(
    root=r"chestctscan\Data\test",
    transform=transform
)


In [16]:
train_loader = DataLoader(
    train_data, batch_size=8, shuffle=True, num_workers=0, pin_memory=True
)
test_loader = DataLoader(
    test_data, batch_size=8, num_workers=0, pin_memory=True
)

In [17]:
class Chest(nn.Module):
  def __init__(self):
    super().__init__()
    self.base = models.resnet18(pretrained=True)
    self.base.fc = nn.Linear(self.base.fc.in_features,4)
  def forward(self,x):
    return self.base(x)

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Chest().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

In [19]:
!pip install tqdm


In [20]:
from tqdm import tqdm
def train():
  model.train()
  total_loss = 0
  for images,labels in tqdm(train_loader, desc="Training", leave=False):
    images,labels = images.to(device),labels.to(device)
    optimizer.zero_grad()
    output = model(images)
    loss = loss_fn(output,labels)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
  return total_loss/len(train_loader)


In [21]:
def evaluate():
  model.eval()
  correct,total = 0,0
  with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating", leave=False):
      images,labels = images.to(device),labels.to(device)
      preds = torch.argmax(model(images),dim=1)
      correct += (preds==labels).sum().item()
      total += labels.size(0)
  return correct/total

In [23]:
for epoch in tqdm(range(1, 30), desc="Epochs"):
  avg_loss = train()
  acc = evaluate()
  print(f"Epoch: {epoch}, Loss: {avg_loss:.4f}, Accuracy: {acc:.4f}")

Epochs:   3%|▎         | 1/29 [00:08<04:07,  8.84s/it]

Epoch: 1, Loss: 0.3242, Accuracy: 0.8000


Epochs:   7%|▋         | 2/29 [00:17<03:51,  8.56s/it]

Epoch: 2, Loss: 0.2439, Accuracy: 0.7778


Epochs:  10%|█         | 3/29 [00:25<03:38,  8.42s/it]

Epoch: 3, Loss: 0.1714, Accuracy: 0.7968


Epochs:  14%|█▍        | 4/29 [00:33<03:28,  8.36s/it]

Epoch: 4, Loss: 0.1647, Accuracy: 0.7429


Epochs:  17%|█▋        | 5/29 [00:41<03:19,  8.32s/it]

Epoch: 5, Loss: 0.1571, Accuracy: 0.8825


Epochs:  21%|██        | 6/29 [00:50<03:10,  8.29s/it]

Epoch: 6, Loss: 0.0997, Accuracy: 0.8730


Epochs:  24%|██▍       | 7/29 [00:58<03:02,  8.29s/it]

Epoch: 7, Loss: 0.0965, Accuracy: 0.8413


Epochs:  28%|██▊       | 8/29 [01:06<02:53,  8.27s/it]

Epoch: 8, Loss: 0.1030, Accuracy: 0.8190


Epochs:  31%|███       | 9/29 [01:14<02:44,  8.23s/it]

Epoch: 9, Loss: 0.0752, Accuracy: 0.8127


Epochs:  34%|███▍      | 10/29 [01:23<02:36,  8.25s/it]

Epoch: 10, Loss: 0.0565, Accuracy: 0.8063


Epochs:  38%|███▊      | 11/29 [01:31<02:28,  8.23s/it]

Epoch: 11, Loss: 0.0376, Accuracy: 0.8857


Epochs:  41%|████▏     | 12/29 [01:39<02:19,  8.21s/it]

Epoch: 12, Loss: 0.0407, Accuracy: 0.8698


Epochs:  45%|████▍     | 13/29 [01:47<02:10,  8.16s/it]

Epoch: 13, Loss: 0.0516, Accuracy: 0.8825


Epochs:  48%|████▊     | 14/29 [01:55<02:02,  8.15s/it]

Epoch: 14, Loss: 0.0353, Accuracy: 0.8825


Epochs:  52%|█████▏    | 15/29 [02:03<01:53,  8.14s/it]

Epoch: 15, Loss: 0.0486, Accuracy: 0.7873


Epochs:  55%|█████▌    | 16/29 [02:11<01:45,  8.12s/it]

Epoch: 16, Loss: 0.1049, Accuracy: 0.8381


Epochs:  59%|█████▊    | 17/29 [02:19<01:37,  8.11s/it]

Epoch: 17, Loss: 0.0407, Accuracy: 0.9016


Epochs:  62%|██████▏   | 18/29 [02:28<01:29,  8.18s/it]

Epoch: 18, Loss: 0.0344, Accuracy: 0.8762


Epochs:  66%|██████▌   | 19/29 [02:36<01:21,  8.19s/it]

Epoch: 19, Loss: 0.0484, Accuracy: 0.8571


Epochs:  69%|██████▉   | 20/29 [02:44<01:13,  8.18s/it]

Epoch: 20, Loss: 0.0425, Accuracy: 0.9111


Epochs:  72%|███████▏  | 21/29 [02:52<01:05,  8.16s/it]

Epoch: 21, Loss: 0.0491, Accuracy: 0.8667


Epochs:  76%|███████▌  | 22/29 [03:00<00:56,  8.12s/it]

Epoch: 22, Loss: 0.0733, Accuracy: 0.8286


Epochs:  79%|███████▉  | 23/29 [03:09<00:48,  8.17s/it]

Epoch: 23, Loss: 0.0331, Accuracy: 0.9079


Epochs:  83%|████████▎ | 24/29 [03:17<00:40,  8.18s/it]

Epoch: 24, Loss: 0.0261, Accuracy: 0.8413


Epochs:  86%|████████▌ | 25/29 [03:25<00:32,  8.16s/it]

Epoch: 25, Loss: 0.0334, Accuracy: 0.8286


Epochs:  90%|████████▉ | 26/29 [03:33<00:24,  8.18s/it]

Epoch: 26, Loss: 0.0268, Accuracy: 0.7810


Epochs:  93%|█████████▎| 27/29 [03:41<00:16,  8.19s/it]

Epoch: 27, Loss: 0.0372, Accuracy: 0.9365


Epochs:  97%|█████████▋| 28/29 [03:50<00:08,  8.23s/it]

Epoch: 28, Loss: 0.0389, Accuracy: 0.8413


Epochs: 100%|██████████| 29/29 [03:58<00:00,  8.22s/it]

Epoch: 29, Loss: 0.0358, Accuracy: 0.8889





In [24]:
torch.save(model.state_dict(), "chestmri.pth")

In [25]:
print("Model saved as chestmri.pth")

Model saved as chestmri.pth


<function list.count(value, /)>