In [7]:
from torch import nn, optim
import torch
from torchvision.models import resnet50, ResNet50_Weights
import os
import time
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset
from torchvision.io import read_image, ImageReadMode
from torchvision.datasets import ImageFolder

In [8]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print('Running on MPS')
elif torch.cuda.is_available():
    os.environ["USE_FLASH_ATTENTION"] = "1"
    device = torch.device("cuda:0")
    torch.backends.cuda.matmul.allow_tf32 = True


    torch.backends.cudnn.allow_tf32 = True
    print('Running on CUDA')
else:
    device = torch.device("cpu")
    print('Running on CPU')

Running on CUDA


In [10]:

weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
model=resnet50(weights=weights)
model.fc = nn.Linear(model.fc.in_features, 4)
model = model.to(device)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters())

In [12]:

train_dataloader = torch.utils.data.DataLoader(
    ImageFolder("Brain_Tumor_Datasets_Classified\\train",transforms.Compose([
    transforms.Resize((224, 224)),   
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
])),
    batch_size=32, shuffle=True,pin_memory=True,num_workers=8)
test_dataloader = torch.utils.data.DataLoader(
    ImageFolder("Brain_Tumor_Datasets_Classified\\test",transforms.Compose([
    transforms.Resize((224, 224)),   
     transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])),
    batch_size=32, shuffle=True,pin_memory=True,num_workers=8)


In [15]:
#### Train model
train_loss = []
train_accuary = []
test_loss = []
test_accuary = []

num_epochs = 160
start_time = time.time()  
for epoch in range(num_epochs):  
    print("Epoch {} running".format(epoch))  
    model.train()  
    running_loss = 0.  
    running_corrects = 0
    # load a batch data of images
    for i, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_corrects += torch.sum(preds == labels.data).item()
    epoch_loss = running_loss / len(train_dataloader.dataset)
    epoch_acc = running_corrects / len(train_dataloader.dataset) * 100.

    train_loss.append(epoch_loss)
    train_accuary.append(epoch_acc)
    print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch + 1, epoch_loss, epoch_acc,
                                                                       time.time() - start_time))

    if epoch % 5 == 0 or epoch == 0 or epoch==num_epochs:
        model.eval()
        with torch.no_grad():
            running_loss = 0.
            running_corrects = 0
            for inputs, labels in test_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()
                running_corrects += torch.sum(preds == labels.data).item()
            epoch_loss = running_loss / len(test_dataloader.dataset)
            epoch_acc = running_corrects / len(test_dataloader.dataset) * 100.
          
            test_loss.append(epoch_loss)
            test_accuary.append(epoch_acc)
       

        print('[Test #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch + 1, epoch_loss, epoch_acc,
                                                                          time.time() - start_time))
        torch.save(model.state_dict(), f"ModelDumps\\resnet{epoch}.pth")

Epoch 0 running
[Train #1] Loss: 0.0016 Acc: 98.1887% Time: 39.4932s
[Test #1] Loss: 0.0017 Acc: 98.6135% Time: 60.9775s
Epoch 1 running
[Train #2] Loss: 0.0020 Acc: 97.5656% Time: 100.2456s
Epoch 2 running
[Train #3] Loss: 0.0018 Acc: 97.8264% Time: 139.6828s
Epoch 3 running
[Train #4] Loss: 0.0016 Acc: 97.9278% Time: 178.7647s
[Test #4] Loss: 0.0018 Acc: 98.7868% Time: 200.5402s
Epoch 4 running
[Train #5] Loss: 0.0014 Acc: 98.4495% Time: 239.7035s
Epoch 5 running
[Train #6] Loss: 0.0018 Acc: 97.8844% Time: 279.4456s
Epoch 6 running
[Train #7] Loss: 0.0017 Acc: 98.0438% Time: 318.5111s
[Test #7] Loss: 0.0013 Acc: 98.9601% Time: 339.8024s
Epoch 7 running
[Train #8] Loss: 0.0018 Acc: 97.9568% Time: 378.0259s
Epoch 8 running
[Train #9] Loss: 0.0018 Acc: 98.0003% Time: 416.0800s
Epoch 9 running
[Train #10] Loss: 0.0019 Acc: 97.8699% Time: 455.9475s
[Test #10] Loss: 0.0015 Acc: 98.8446% Time: 477.2941s
Epoch 10 running
[Train #11] Loss: 0.0017 Acc: 97.9568% Time: 515.9702s
Epoch 11 running

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x00000153CB001C60>
Traceback (most recent call last):
  File "C:\Users\MKP_Desktop\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\utils\data\dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "C:\Users\MKP_Desktop\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\utils\data\dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "C:\Users\MKP_Desktop\AppData\Local\Programs\Python\Python312\Lib\multiprocessing\process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\MKP_Desktop\AppData\Local\Programs\Python\Python312\Lib\multiprocessing\popen_spawn_win32.py", line 110, in wait
    res = _winapi.WaitForSingleObject(int(self._handle), msecs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt: 


KeyboardInterrupt: 