In [1]:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split

# Per-split preprocessing pipelines, keyed by split name.
# Both end with the same ToTensor + Normalize steps; the mean/std values are
# the ImageNet channel statistics that torchvision's pretrained models expect.
data_transforms = dict(
    # Training: random crop/flip augmentation down to 224x224.
    train=transforms.Compose([
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
    # Validation/inference: deterministic resize followed by a centre crop.
    val=transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
)

In [7]:
# Load dataset
# Two views over the same image folder: identical files and labels, but each
# applies its own preprocessing. Splitting a single train-transformed dataset
# would apply random crop/flip augmentation to the validation images as well.
train_view = datasets.ImageFolder('imageset', data_transforms['train'])
val_view = datasets.ImageFolder('imageset', data_transforms['val'])
dataset = train_view  # kept under the original name: `dataset.classes` is read below

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# One seeded permutation shared by both views so the split is disjoint and
# reproducible across kernel restarts (the seed value itself is arbitrary).
split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(dataset), generator=split_generator).tolist()
train_dataset = torch.utils.data.Subset(train_view, shuffled_indices[:train_size])
test_dataset = torch.utils.data.Subset(val_view, shuffled_indices[train_size:])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
# Evaluation only aggregates loss/accuracy over the whole split, so shuffling
# the validation loader adds nothing.
val_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4)

dataloaders = {'train': train_loader, 'val': val_loader}
dataset_sizes = {'train': len(train_dataset), 'val': len(test_dataset)}
class_names = dataset.classes

In [9]:
from torchvision import models
import torch.nn as nn

# Start from a pretrained ResNet-18.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` - confirm the installed version before switching.
model_ft = models.resnet18(pretrained=True)

# Swap the final fully-connected layer for a fresh 2-output head that keeps
# the original layer's input width.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(in_features=num_ftrs, out_features=2)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model_ft = model_ft.to(device)

In [10]:
from torch import optim
import time
import copy

# Loss for the classification head.
criterion = nn.CrossEntropyLoss()

# Every parameter of the network is optimised with SGD + momentum.
optimizer_ft = optim.SGD(
    params=model_ft.parameters(),
    lr=0.001,
    momentum=0.9,
)

# Multiply the learning rate by 0.1 after every 7 scheduler steps.
exp_lr_scheduler = optim.lr_scheduler.StepLR(
    optimizer=optimizer_ft,
    step_size=7,
    gamma=0.1,
)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a 'train' phase followed by a 'val' phase. Batches come
    from the module-level ``dataloaders`` dict, are moved to the module-level
    ``device``, and per-phase totals are normalised by the module-level
    ``dataset_sizes`` dict. Progress is printed to stdout.

    Args:
        model: network to optimise; its parameters are updated in place.
        criterion: callable ``criterion(outputs, labels)`` whose result
            supports ``.backward()`` and ``.item()``.
        optimizer: stepped once per training batch.
        scheduler: stepped once per epoch, after the train phase.
        num_epochs: number of train/val epochs to run.

    Returns:
        The same ``model`` object, with the state dict of the first epoch
        that reached the highest validation accuracy restored.
    """
    since = time.time()

    # Snapshot the initial weights so there is something to restore even if
    # no epoch ever beats an accuracy of 0.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Autograd is only needed while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # The criterion result is weighted by batch size so that the
                # division by the split size below yields a per-sample value.
                running_loss += loss.item() * inputs.size(0)
                # Accumulate a plain int so epoch_acc / best_acc are Python
                # floats rather than device tensors.
                running_corrects += torch.sum(preds == labels).item()
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Strict '>' keeps the earliest epoch on ties.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # Four decimals, matching the per-epoch accuracy lines above.
    print('Best val Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model

# Run the training loop for 25 epochs and rebind model_ft to the returned model.
model_ft = train_model(
    model=model_ft,
    criterion=criterion,
    optimizer=optimizer_ft,
    scheduler=exp_lr_scheduler,
    num_epochs=25,
)

Epoch 0/24
----------
train Loss: 0.6161 Acc: 0.7391
val Loss: 0.7830 Acc: 0.4167

Epoch 1/24
----------
train Loss: 0.6298 Acc: 0.7174
val Loss: 0.9692 Acc: 0.5833

Epoch 2/24
----------
train Loss: 0.4672 Acc: 0.7609
val Loss: 0.3768 Acc: 0.7500

Epoch 3/24
----------
train Loss: 0.2892 Acc: 0.9130
val Loss: 0.2087 Acc: 0.9167

Epoch 4/24
----------
train Loss: 0.2182 Acc: 0.9348
val Loss: 0.5464 Acc: 0.9167

Epoch 5/24
----------
train Loss: 0.5380 Acc: 0.8261
val Loss: 0.5835 Acc: 0.6667

Epoch 6/24
----------
train Loss: 0.2201 Acc: 0.8696
val Loss: 0.6265 Acc: 0.6667

Epoch 7/24
----------
train Loss: 0.5570 Acc: 0.8043
val Loss: 0.2924 Acc: 0.8333

Epoch 8/24
----------
train Loss: 0.2745 Acc: 0.8261
val Loss: 0.1746 Acc: 0.9167

Epoch 9/24
----------
train Loss: 0.2331 Acc: 0.8913
val Loss: 0.4495 Acc: 0.8333

Epoch 10/24
----------
train Loss: 0.2042 Acc: 0.9130
val Loss: 0.2283 Acc: 0.9167

Epoch 11/24
----------
train Loss: 0.3774 Acc: 0.8696
val Loss: 0.4318 Acc: 0.8333

Ep

In [15]:
import cv2
from PIL import Image
import torch
import time

cap = cv2.VideoCapture('testvideo.mp4')
frame_rate = cap.get(cv2.CAP_PROP_FPS)

# Seconds to stay silent after a positive detection.
detection_pause = 5
# CAP_PROP_FPS is returned as a float and may be fractional (e.g. 29.97).
# The counter below is decremented by one per frame, so the budget must be a
# whole number; a fractional budget would step past zero and detection would
# never resume after the first hit.
frames_to_skip = max(0, int(round(detection_pause * frame_rate)))
skip_frames_counter = 0

# Make inference mode explicit instead of relying on whatever mode an earlier
# cell left the model in.
model_ft.eval()

try:
    # No gradients are needed for inference; skip autograd bookkeeping.
    with torch.no_grad():
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Still inside the pause window after a detection: consume the frame.
            if skip_frames_counter > 0:
                skip_frames_counter -= 1
                continue

            # OpenCV delivers BGR arrays; the transform pipeline expects an RGB PIL image.
            frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_transformed = data_transforms['val'](frame_pil).unsqueeze(0).to(device)

            outputs = model_ft(frame_transformed)
            _, preds = torch.max(outputs, 1)

            # NOTE(review): class index 1 is treated as "shrimp" - confirm
            # against the class-to-index mapping of the training dataset.
            if preds[0] == 1:
                current_time = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000
                print(f"Shrimp detected at {current_time:.2f} seconds.")
                skip_frames_counter = frames_to_skip
finally:
    # Release the capture handle even if a frame fails mid-loop.
    cap.release()


Shrimp detected at 6.43 seconds.
Shrimp detected at 11.48 seconds.
Shrimp detected at 16.52 seconds.
Shrimp detected at 21.57 seconds.
Shrimp detected at 26.61 seconds.
Shrimp detected at 31.65 seconds.
Shrimp detected at 36.70 seconds.
Shrimp detected at 41.74 seconds.
Shrimp detected at 46.78 seconds.
Shrimp detected at 51.83 seconds.
Shrimp detected at 56.87 seconds.
Shrimp detected at 61.91 seconds.
Shrimp detected at 66.96 seconds.
Shrimp detected at 72.00 seconds.


In [16]:
# Write the model's state dict (not the whole module object) to disk.
torch.save(obj=model_ft.state_dict(), f='model_state_dict.pth')

In [22]:
from PIL import Image
import torch

image_path = 'shrimp.png'
# Close the file handle as soon as the pixels have been copied into an RGB image.
with Image.open(image_path) as raw_image:
    image = raw_image.convert('RGB')

# Same deterministic preprocessing as validation, plus a leading batch dimension.
image_transformed = data_transforms['val'](image).unsqueeze(0).to(device)

model_ft.eval()
# No gradients are needed for a single forward pass.
with torch.no_grad():
    outputs = model_ft(image_transformed)
_, preds = torch.max(outputs, 1)

# NOTE(review): class index 1 is treated as "shrimp" - confirm against the
# class-to-index mapping of the training dataset.
if preds[0] == 1:
    print("Shrimp detected in the image.")
else:
    print("No shrimp detected in the image.")


Shrimp detected in the image.
