In [1]:
import torch 
from torchvision import models,transforms,datasets
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm

In [2]:
resnet=models.resnet50(pretrained=True)



In [3]:
for param in resnet.parameters():
    param.requires_grad=False

In [4]:
num_features=resnet.fc.in_features
num_features

2048

In [5]:
resnet.fc=nn.Sequential(
    nn.Linear(2048,512),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.4),
    nn.Linear(512,4)
)

In [6]:
resnet=resnet.to("cuda")

In [7]:
x=torch.rand(1,3,244,244).to("cuda")
out=resnet(x)
out.shape

torch.Size([1, 4])

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.fc.parameters(), lr=1e-3)

In [9]:
transformation=transforms.Compose([transforms.Resize(size=(244,244)),transforms.ToTensor()])

In [10]:
train_data=datasets.ImageFolder(root='/home/deepesh/intern_tumour/brain_tumour/Tumour/train',transform=transformation)
val_data=datasets.ImageFolder(root='/home/deepesh/intern_tumour/brain_tumour/Tumour/valid',transform=transformation)
test_data=datasets.ImageFolder(root='/home/deepesh/intern_tumour/brain_tumour/Tumour/test',transform=transformation)

In [11]:
train_loader=DataLoader(dataset=train_data,batch_size=32,shuffle=True)
val_loader=DataLoader(dataset=val_data,batch_size=32,shuffle=False)
test_loader=DataLoader(dataset=test_data,batch_size=32,shuffle=False)

In [12]:
epochs=15
for epoch in tqdm(range(epochs)):
    torch.cuda.empty_cache()
    train_loss=0
    val_loss=0
    val_correct=0
    val_total=0
    resnet.train()
    for imgs,labels in train_loader:
        imgs=imgs.to("cuda")
        labels=labels.to("cuda")
        out=resnet(imgs)
        loss=criterion(out,labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss+=loss
    with torch.no_grad():
        resnet.eval()
        for imgs,labels in val_loader:
            imgs=imgs.to("cuda")
            labels=labels.to("cuda")
            out=resnet(imgs)
            loss=criterion(out,labels)
            pred=torch.argmax(out,1)
            val_correct+=(pred==labels).sum().item()
            val_total+=labels.size(0)
            val_loss+=loss
    val_accuracy=100*(val_correct/val_total)
    train_loss=100*(train_loss/len(train_loader))
    print(f"epoch:{epoch} trainloss:{train_loss} val_accuracy:{val_accuracy} val_loss:{val_loss}")
    
        

  7%|██▉                                         | 1/15 [00:16<03:49, 16.38s/it]

epoch:0 trainloss:90.28559112548828 val_accuracy:76.69322709163346 val_loss:10.090800285339355


 13%|█████▊                                      | 2/15 [00:32<03:29, 16.12s/it]

epoch:1 trainloss:52.03203582763672 val_accuracy:85.65737051792829 val_loss:6.220680236816406


 20%|████████▊                                   | 3/15 [00:48<03:13, 16.08s/it]

epoch:2 trainloss:50.05191421508789 val_accuracy:86.45418326693228 val_loss:5.814709663391113


 27%|███████████▋                                | 4/15 [01:04<02:57, 16.15s/it]

epoch:3 trainloss:36.561100006103516 val_accuracy:88.64541832669323 val_loss:5.4478230476379395


 33%|██████████████▋                             | 5/15 [01:20<02:41, 16.12s/it]

epoch:4 trainloss:33.45790100097656 val_accuracy:88.04780876494024 val_loss:5.408100605010986


 40%|█████████████████▌                          | 6/15 [01:36<02:24, 16.08s/it]

epoch:5 trainloss:32.810672760009766 val_accuracy:85.4581673306773 val_loss:6.721045970916748


 47%|████████████████████▌                       | 7/15 [01:52<02:08, 16.04s/it]

epoch:6 trainloss:31.115446090698242 val_accuracy:86.85258964143426 val_loss:5.778832912445068


 53%|███████████████████████▍                    | 8/15 [02:08<01:52, 16.01s/it]

epoch:7 trainloss:29.069185256958008 val_accuracy:88.44621513944223 val_loss:4.990807056427002


 60%|██████████████████████████▍                 | 9/15 [02:24<01:35, 15.97s/it]

epoch:8 trainloss:29.577245712280273 val_accuracy:89.44223107569721 val_loss:5.0740227699279785


 67%|████████████████████████████▋              | 10/15 [02:40<01:19, 15.97s/it]

epoch:9 trainloss:26.935636520385742 val_accuracy:88.64541832669323 val_loss:5.510217666625977


 73%|███████████████████████████████▌           | 11/15 [02:56<01:03, 15.97s/it]

epoch:10 trainloss:24.854042053222656 val_accuracy:90.0398406374502 val_loss:4.6307291984558105


 80%|██████████████████████████████████▍        | 12/15 [03:12<00:47, 15.99s/it]

epoch:11 trainloss:28.681787490844727 val_accuracy:86.65338645418326 val_loss:5.297514915466309


 87%|█████████████████████████████████████▎     | 13/15 [03:28<00:32, 16.07s/it]

epoch:12 trainloss:24.177305221557617 val_accuracy:89.2430278884462 val_loss:5.282051086425781


 93%|████████████████████████████████████████▏  | 14/15 [03:44<00:16, 16.07s/it]

epoch:13 trainloss:26.6910343170166 val_accuracy:82.86852589641434 val_loss:6.584343910217285


100%|███████████████████████████████████████████| 15/15 [04:00<00:00, 16.06s/it]

epoch:14 trainloss:25.529909133911133 val_accuracy:90.43824701195219 val_loss:4.373863220214844





In [13]:
def test():
    val_correct=0
    val_total=0
    for img,labels in test_loader:
        img=img.to("cuda")
        labels=labels.to("cuda")
        out=resnet(img)
        pred=torch.argmax(out,1)
        val_correct+=(pred==labels).sum().item()
        val_total+=labels.size(0)
    accuracy=100*(val_correct/val_total)
    print(f"the accuracy:{accuracy}")
        

In [14]:
test()

the accuracy:91.46341463414635
