In [2]:
import cv2
import torch
import torch.nn as nn

In [3]:
class CNNModel(nn.Module):
    """Small convolutional network that maps an image batch to four values in [0, 1].

    Three ``Conv2d -> ReLU -> MaxPool2d`` stages feed an adaptive average pool
    that always yields an 8x8 map per channel, so the fully connected head
    sees ``128 * 8 * 8`` features regardless of the input resolution.

    Attribute names are part of the checkpoint format (``state_dict`` keys),
    so they must not be renamed.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: 3 -> 32 -> 64 -> 128 channels, halving H/W after each conv.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Despite the name, this pools to a fixed 8x8 grid rather than 1x1.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((8, 8))

        # Regression head: 8192 -> 128 -> 64 -> 4.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(128 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 4)

    def forward(self, x):
        """Run the network on ``x`` (a batch with 3 channels) and return an ``(N, 4)`` tensor.

        The final sigmoid squashes every output into the range [0, 1].
        """
        conv_stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in conv_stages:
            x = pool(torch.relu(conv(x)))

        x = self.flatten(self.global_avg_pool(x))

        for hidden in (self.fc1, self.fc2):
            x = torch.relu(hidden(x))

        return torch.sigmoid(self.fc3(x))

In [4]:
def load_checkpoint(checkpoint, architecture, optimizer=None, device=None):
    """Rebuild a model from a saved checkpoint and return it in eval mode.

    Args:
        checkpoint: Path (or file-like object) accepted by ``torch.load``. The
            loaded object must be a mapping with a ``'state_dict'`` entry and,
            when ``optimizer`` is given, an ``'optimizer'`` entry.
        architecture: Zero-argument callable that returns a fresh, untrained
            model instance (e.g. the model class itself).
        optimizer: Optional optimizer whose state is restored in place from
            the checkpoint. Pass ``None`` for inference-only loading.
        device: Target device (string or ``torch.device``). Defaults to CUDA
            when available and CPU otherwise, so the function also works on
            machines without a GPU.

    Returns:
        The model with restored weights, moved to ``device`` and switched to
        evaluation mode.
    """
    print("loading checkpoint...")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints from
    # sources you trust. map_location lets GPU-saved checkpoints open on CPU.
    saved_state = torch.load(checkpoint, map_location=device)

    model = architecture()
    model.load_state_dict(saved_state['state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(saved_state['optimizer'])

    return model.to(device).eval()

In [5]:
# Inference setup: choose the device, point at the saved checkpoint, rebuild the network.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = r"D:\Praharsha\code\CAMZ\models\model_history\1.0-CNN\CNN_checkpoint.pth.tar"  # Change to your model path

# load_checkpoint also restores optimizer state, so an Adam optimizer built
# over a throwaway CNNModel instance is handed in alongside the path.
optimizer = torch.optim.Adam(CNNModel().parameters(), lr=0.001)
model = load_checkpoint(model_path, CNNModel, optimizer)
model.eval()

loading checkpoint...


  checkpoint = torch.load(checkpoint)


CNNModel(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (global_avg_pool): AdaptiveAvgPool2d(output_size=(8, 8))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=8192, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=4, bias=True)
)

In [6]:
def yolo_to_pixel_coords(box, img_width, img_height):
    """Turn a normalised centre-format box into integer corner coordinates.

    Args:
        box: Four values ``(x_center, y_center, width, height)``, each
            expressed as a fraction of the image size.
        img_width: Image width in pixels.
        img_height: Image height in pixels.

    Returns:
        ``(x1, y1, x2, y2)``: the top-left and bottom-right corners in pixels,
        each truncated toward zero by ``int``.
    """
    cx, cy, box_w, box_h = box
    half_w = box_w / 2
    half_h = box_h / 2

    left = int((cx - half_w) * img_width)
    top = int((cy - half_h) * img_height)
    right = int((cx + half_w) * img_width)
    bottom = int((cy + half_h) * img_height)

    return left, top, right, bottom

In [7]:
def visualize_video_output(video_path, output_path, model, input_size=(224, 224)):
    """Run ``model`` on every frame of a video and save a copy with the predicted box drawn.

    Each frame is resized to ``input_size``, scaled to [0, 1], and passed to
    the model. The first four values of every prediction row are treated as a
    normalised ``(x_center, y_center, width, height)`` box and drawn in red on
    the original-resolution frame, which is then written to ``output_path``.

    Args:
        video_path: Path of the input video.
        output_path: Path of the annotated video to write (``mp4v`` codec).
        model: Callable ``torch.nn.Module`` returning a ``(N, >=4)`` tensor.
            Frames are moved to the device of the model's first parameter
            (CPU if the model has no parameters).
        input_size: ``(width, height)`` the frame is resized to before
            inference.

    Raises:
        IOError: If the input video cannot be opened or the output writer
            cannot be created.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open input video: {video_path}")

    out = None
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the frame rate as a float: truncating e.g. 29.97 to 29 would
        # change the playback speed/duration of the output.
        fps = cap.get(cv2.CAP_PROP_FPS)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            raise IOError(f"Could not open output video for writing: {output_path}")

        # Follow the model's own device instead of relying on a notebook global.
        model_device = next((p.device for p in model.parameters()), torch.device("cpu"))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            resized_frame = cv2.resize(frame, input_size)
            # HWC uint8 -> 1xCxHxW float in [0, 1].
            # NOTE(review): the frame keeps OpenCV's channel order (BGR by
            # default) - confirm the model was trained with the same order.
            img_tensor = torch.from_numpy(resized_frame).permute(2, 0, 1).float().unsqueeze(0) / 255.0
            img_tensor = img_tensor.to(model_device)

            with torch.no_grad():
                preds = model(img_tensor)

            for pred in preds:
                x1, y1, x2, y2 = yolo_to_pixel_coords(pred[:4].tolist(), frame_width, frame_height)
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2)

            out.write(frame)
    finally:
        # Always release the handles so the output file is finalised even on error.
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

In [14]:
# Annotate the Ntd3_ctrl8 clip; the result is saved beside the input with an "_out" suffix.
vid_1 = r"C:\Users\USER\Downloads\Ntd3_ctrl8.mp4"
vid_1_output = r"C:\Users\USER\Downloads\Ntd3_ctrl8_out.mp4"
visualize_video_output(video_path=vid_1, output_path=vid_1_output, model=model)

In [10]:
# Annotate the ctrl13_30 clip from output_clips; the result goes one folder up.
vid_1 = r"D:\Zebra_Fish_Dataset\NTD_batch_III_videos\output_clips\ctrl13_30.mp4"
vid_1_output = r"D:\Zebra_Fish_Dataset\NTD_batch_III_videos\ctrl13_30_out.mp4"
visualize_video_output(video_path=vid_1, output_path=vid_1_output, model=model)

In [21]:
# Annotate the 3Rcrop novel-tank clip; the result is saved beside the input with an "_out" suffix.
vid_2 = r"D:\Zebra Fish Dataset\novel tank cropped videos\novel tank cropped videos\3Rcrop.mp4"
vid_2_output = r"D:\Zebra Fish Dataset\novel tank cropped videos\novel tank cropped videos\3Rcrop_out.mp4"
visualize_video_output(video_path=vid_2, output_path=vid_2_output, model=model)

In [20]:
# Annotate the Lcrop novel-tank clip; the result is saved beside the input with an "_out" suffix.
vid_3 = r"D:\Zebra Fish Dataset\novel tank cropped videos\novel tank cropped videos\Lcrop.mp4"
vid_3_output = r"D:\Zebra Fish Dataset\novel tank cropped videos\novel tank cropped videos\Lcrop_out.mp4"
visualize_video_output(video_path=vid_3, output_path=vid_3_output, model=model)