In [51]:
import torch

# Display the installed PyTorch version so the outputs below can be tied to it.
torch.__version__

'2.6.0'

In [70]:
from functools import lru_cache
from io import BytesIO
from pathlib import Path
from typing import Optional

import numpy as np
import requests
import torch
from PIL import Image
from torchvision.models import resnet18, ResNet18_Weights

# Load the human-readable label for each model output index, one label per line.
# The explicit encoding keeps this independent of the platform's default locale.
# Blank lines are deliberately kept so a label's line position still matches the
# output index that `predict` uses to look it up.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = [line.strip() for line in f]

@lru_cache(maxsize=1)
def _load_resnet18():
    """Build torchvision's default pretrained ResNet-18 and its transform, once.

    Cached so repeated `predict` calls do not rebuild the network and reload
    its pretrained weights on every invocation.

    Returns:
        (model, transform): the model already switched to eval mode, and the
        preprocessing transform bundled with ``ResNet18_Weights.DEFAULT``.
    """
    weights = ResNet18_Weights.DEFAULT
    model = resnet18(weights=weights)
    model.eval()
    return model, weights.transforms()


def _load_image(img_path: Optional[str] = None, timeout: float = 10.0) -> "Image.Image":
    """Open an image from a local path or an http(s) URL and convert it to RGB.

    Args:
        img_path: Local file path or http(s) URL. ``None`` falls back to the
            sample image ``images/steak.jpeg``.
        timeout: Seconds requests waits on the server before giving up.

    Raises:
        requests.HTTPError: if the URL answers with a 4xx/5xx status.
    """
    if img_path is None:
        img_path = "images/steak.jpeg"
    # Match the full scheme, not just the letters "http", so a local file
    # such as "http_cat.jpg" is not mistaken for a URL.
    if img_path.startswith(("http://", "https://")):
        response = requests.get(img_path, timeout=timeout)
        # Fail loudly on an error status instead of handing an HTML error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    # The context manager closes the file handle once the pixels are converted.
    with Image.open(Path(img_path)) as image:
        return image.convert("RGB")


def predict(img_path: Optional[str] = None) -> str:
    """Classify an image with torchvision's default pretrained ResNet-18.

    Args:
        img_path: Local image path or http(s) URL; ``None`` uses
            ``images/steak.jpeg``.

    Returns:
        The entry of ``class_names`` at the highest-scoring output index.
        The label is also printed.
    """
    resnet_model, resnet_transform = _load_resnet18()
    image = _load_image(img_path)

    # PIL -> numpy gives (H, W, C) uint8; reorder to (C, H, W) for the transform.
    img = torch.from_numpy(np.array(image)).permute(2, 0, 1)
    img = resnet_transform(img)

    with torch.inference_mode():
        logits = resnet_model(img.unsqueeze(0))  # add a batch dimension of 1
    # softmax is monotonic, so the argmax of the raw logits is the same class.
    pred_class = logits.argmax(dim=1).item()
    predicted_label = class_names[pred_class]
    print(f"Predicted class: {predicted_label}")
    return predicted_label


In [71]:
# Classify a remote image; predict() downloads it because it is an http(s) URL.
sample_url = "https://m.media-amazon.com/images/I/41ypb39SsSL._AC_UF1000,1000_QL80_.jpg"
predict(sample_url)

Predicted class: teddy


'teddy'

In [108]:
from safetensors.torch import save_file
import torch
from torch import nn
# Seed torch's global RNG so the random input and MyModel's initial weights below
# are reproducible across runs.
torch.manual_seed(42)

class MyModel(nn.Module):
    """A single fully connected layer: ``y = x @ W.T + b``.

    Args:
        in_features: Size of the last input dimension. Defaults to 3.
        out_features: Size of the last output dimension. Defaults to 1.

    The layer is stored as ``self.linear``, so the state-dict keys are
    ``linear.weight`` and ``linear.bias``. Keep the defaults at 3/1 so
    ``MyModel()`` matches the parameters saved to ``model.safetensors``.
    """

    def __init__(self, in_features: int = 3, out_features: int = 1):
        super().__init__()
        self.linear = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)



# A (2, 3) batch of uniform [0, 1) values, reproducible thanks to the seed above.
inp = torch.rand(2, 3)
print(inp)

model = MyModel()
model(inp)  # last expression: the (2, 1) output is displayed


tensor([[0.8823, 0.9150, 0.3829],
        [0.9593, 0.3904, 0.6009]])


tensor([[-0.1664],
        [-0.2550]], grad_fn=<AddmmBackward0>)

In [111]:
# Inspect the parameters that are about to be saved to disk.
model.state_dict()

OrderedDict([('linear.weight', tensor([[-0.2811,  0.3391,  0.5090]])),
             ('linear.bias', tensor([-0.4236]))])

In [114]:
from safetensors.torch import save_file

# Write the model's current parameters to model.safetensors in the working directory.
save_file(model.state_dict(), "model.safetensors")

In [121]:
# A fresh instance starts with different random weights, shown here for contrast
# with the values after loading from disk.
test_model = MyModel()
test_model.state_dict()

OrderedDict([('linear.weight', tensor([[-0.1630, -0.3471,  0.0545]])),
             ('linear.bias', tensor([-0.5702]))])

In [123]:
from safetensors.torch import load_file

# Read the tensors back from the same path the save cell wrote to;
# load_file returns a plain dict mapping parameter name -> tensor.
weights = load_file("model.safetensors")
weights

{'linear.bias': tensor([-0.4236]),
 'linear.weight': tensor([[-0.2811,  0.3391,  0.5090]])}

In [127]:
# strict=True (the default, spelled out) fails if any key is missing or unexpected.
test_model.load_state_dict(weights, strict=True)

<All keys matched successfully>

In [129]:
# Check the save -> load round-trip programmatically instead of eyeballing the
# printout: every parameter in test_model must equal the source model's exactly.
loaded_state = test_model.state_dict()
source_state = model.state_dict()
assert loaded_state.keys() == source_state.keys(), "parameter names differ"
assert all(torch.equal(loaded_state[k], source_state[k]) for k in source_state), "weights differ"
loaded_state

OrderedDict([('linear.weight', tensor([[-0.2811,  0.3391,  0.5090]])),
             ('linear.bias', tensor([-0.4236]))])