In [1]:
import torch 
from torchvision import models,transforms,datasets
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm

In [2]:
resnet=models.resnet50(pretrained=True)



In [3]:
for param in resnet.parameters():
    param.requires_grad=False

In [4]:
num_features=resnet.fc.in_features
num_features

2048

In [5]:
resnet.fc=nn.Sequential(
    nn.Linear(2048,512),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.4),
    nn.Linear(512,4)
)

In [6]:
resnet=resnet.to("cuda")

In [7]:
x=torch.rand(1,3,244,244).to("cuda")
out=resnet(x)
out.shape

torch.Size([1, 4])

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.fc.parameters(), lr=1e-3)

In [9]:
transformation=transforms.Compose([transforms.Resize(size=(244,244)),transforms.ToTensor()])

In [10]:
train_data=datasets.ImageFolder(root='/home/deepesh/intern_tumour/brain_tumour/Tumour/train',transform=transformation)
val_data=datasets.ImageFolder(root='/home/deepesh/intern_tumour/brain_tumour/Tumour/valid',transform=transformation)
test_data=datasets.ImageFolder(root='/home/deepesh/intern_tumour/brain_tumour/Tumour/test',transform=transformation)

In [11]:
train_loader=DataLoader(dataset=train_data,batch_size=32,shuffle=True)
val_loader=DataLoader(dataset=val_data,batch_size=32,shuffle=False)
test_loader=DataLoader(dataset=test_data,batch_size=32,shuffle=False)

In [12]:
epochs=10
for epoch in tqdm(range(epochs)):
    torch.cuda.empty_cache()
    train_loss=0
    val_loss=0
    val_correct=0
    val_total=0
    resnet.train()
    for imgs,labels in train_loader:
        imgs=imgs.to("cuda")
        labels=labels.to("cuda")
        out=resnet(imgs)
        loss=criterion(out,labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss+=loss
    with torch.no_grad():
        resnet.eval()
        for imgs,labels in val_loader:
            imgs=imgs.to("cuda")
            labels=labels.to("cuda")
            out=resnet(imgs)
            loss=criterion(out,labels)
            pred=torch.argmax(out,1)
            val_correct+=(pred==labels).sum().item()
            val_total+=labels.size(0)
            val_loss+=loss
    val_accuracy=100*(val_correct/val_total)
    train_loss=100*(train_loss/len(train_loader))
    print(f"epoch:{epoch} trainloss:{train_loss} val_accuracy:{val_accuracy} val_loss:{val_loss}")
    
        

 10%|████▍                                       | 1/10 [00:17<02:34, 17.14s/it]

epoch:0 trainloss:83.2774429321289 val_accuracy:80.67729083665338 val_loss:8.657512664794922


 20%|████████▊                                   | 2/10 [00:33<02:11, 16.45s/it]

epoch:1 trainloss:52.90311050415039 val_accuracy:86.05577689243027 val_loss:6.084044933319092


 30%|█████████████▏                              | 3/10 [00:49<01:53, 16.23s/it]

epoch:2 trainloss:41.92090606689453 val_accuracy:89.2430278884462 val_loss:5.351101398468018


 40%|█████████████████▌                          | 4/10 [01:05<01:36, 16.14s/it]

epoch:3 trainloss:33.787052154541016 val_accuracy:88.44621513944223 val_loss:4.992843151092529


 50%|██████████████████████                      | 5/10 [01:21<01:20, 16.12s/it]

epoch:4 trainloss:35.77688217163086 val_accuracy:88.84462151394422 val_loss:4.822647571563721


 60%|██████████████████████████▍                 | 6/10 [01:37<01:04, 16.12s/it]

epoch:5 trainloss:36.052886962890625 val_accuracy:83.06772908366534 val_loss:7.398009300231934


 70%|██████████████████████████████▊             | 7/10 [01:53<00:48, 16.06s/it]

epoch:6 trainloss:34.162879943847656 val_accuracy:89.44223107569721 val_loss:4.7895331382751465


 80%|███████████████████████████████████▏        | 8/10 [02:09<00:32, 16.10s/it]

epoch:7 trainloss:29.774110794067383 val_accuracy:89.04382470119522 val_loss:4.413508415222168


 90%|███████████████████████████████████████▌    | 9/10 [02:25<00:16, 16.24s/it]

epoch:8 trainloss:26.126724243164062 val_accuracy:89.2430278884462 val_loss:4.4195661544799805


100%|███████████████████████████████████████████| 10/10 [02:41<00:00, 16.20s/it]

epoch:9 trainloss:25.828338623046875 val_accuracy:87.45019920318725 val_loss:5.343942165374756





In [13]:
def test():
    val_correct=0
    val_total=0
    for img,labels in test_loader:
        img=img.to("cuda")
        labels=labels.to("cuda")
        out=resnet(img)
        pred=torch.argmax(out,1)
        val_correct+=(pred==labels).sum().item()
        val_total+=labels.size(0)
    accuracy=100*(val_correct/val_total)
    print(f"the accuracy:{accuracy}")
        

In [14]:
test()

the accuracy:89.43089430894308


In [15]:
torch.save(resnet,"resnet-192.pth")