## Test: CNN + time-series ensemble in PyTorch

In [None]:
#Packages
! pip install requests
! pip install json
! pip install matplotib.pyplot
! pip install time
! pip install numpy 
! pip install yfinance
! pip install pandas
! pip install torch

In [48]:
#Packages
import io
import time

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import yfinance as yf

In [None]:
## Test CNN: square images, 256 x 256 by default (configurable via img_size)
class Torch_CNN(nn.Module):
    """Small two-stage convolutional encoder followed by a linear head.

    Architecture:
        conv3x3(embed -> 8) -> ReLU -> maxpool(2)
        -> conv3x3(8 -> 16) -> ReLU -> maxpool(2)
        -> flatten -> Linear(16 * (img_size // 4)**2 -> num_pred)

    Args:
        embed: number of input channels (e.g. 3 for RGB).
        num_pred: width of the output vector.
        img_size: height/width of the square input image. Defaults to 256,
            which reproduces the original hard-coded 16 * 64 * 64 head.

    Raises:
        ValueError: if img_size < 4 (the two pools would leave no pixels).

    forward(x):
        x: (batch, embed, img_size, img_size) tensor.
        returns: (batch, num_pred) tensor.
    """

    def __init__(self, embed, num_pred, img_size=256):
        super(Torch_CNN, self).__init__()
        if img_size < 4:
            raise ValueError(f"img_size must be >= 4, got {img_size}")
        self.conv1 = nn.Conv2d(in_channels=embed, out_channels=8, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        # padding=1 with a 3x3 kernel keeps the spatial size through each conv;
        # each of the two 2x2 pools floors it by 2, leaving (img_size // 4)^2 cells.
        pooled = img_size // 4
        self.fc1 = nn.Linear(16 * pooled * pooled, num_pred)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.pool(x)
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        # Flatten everything except the batch dimension (robust to non-contiguous x).
        x = torch.flatten(x, 1)
        return self.fc1(x)


## Test FFN: in Torch_Ens it receives the CNN features concatenated with the time-series values
class Torch_FFN(nn.Module):
    """Two-layer perceptron: Linear(embed, hidden) -> ReLU -> Linear(hidden, num_pred).

    Args:
        embed: width of the input feature vector.
        hidden: width of the hidden layer.
        num_pred: width of the output.

    forward(x):
        x: (..., embed) tensor.
        returns: (..., num_pred) tensor.
    """

    def __init__(self, embed, hidden, num_pred):
        super().__init__()
        self.fc1 = nn.Linear(embed, hidden)
        self.fc2 = nn.Linear(hidden, num_pred)

    def forward(self, x):
        hidden_act = torch.relu(self.fc1(x))
        return self.fc2(hidden_act)


# Test ensemble: CNN image encoder + FFN head over [CNN features, time series]
class Torch_Ens(nn.Module):
    """Fuse image features and time-series values into one prediction.

    The image batch goes through Torch_CNN, which maps it to a
    ``hidden_dim``-wide feature vector. That vector is concatenated along
    dim 1 with the time-series input, and the result is fed to Torch_FFN.

    Args:
        embed_img: number of image channels.
        embed_ts: width of the time-series input.
        hidden_dim: CNN output width and FFN hidden width.
        num_pred: number of predicted values.

    forward(x_img, y_ts):
        x_img: image batch; Torch_CNN's head is sized for 256 x 256 inputs.
        y_ts: (batch, embed_ts) tensor.
        returns: (batch, num_pred) tensor.
    """

    def __init__(self, embed_img, embed_ts, hidden_dim, num_pred):
        super(Torch_Ens, self).__init__()
        self.cnn = Torch_CNN(embed_img, hidden_dim)
        fused_width = hidden_dim + embed_ts
        self.fnn = Torch_FFN(fused_width, hidden_dim, num_pred)

    def forward(self, x_img, y_ts):
        img_features = self.cnn(x_img)
        fused = torch.cat((img_features, y_ts), dim=1)  # combine both modalities
        return self.fnn(fused)


In [56]:
# Smoke test: build the ensemble and check the output shape on random inputs
torch.manual_seed(0)  # reproducible weights and inputs

## The model
model = Torch_Ens(
    embed_img=3,     # RGB image channels
    embed_ts=4,      # 4 time-series values (presumably one per ticker below)
    hidden_dim=256,
    num_pred=1,
)

batch_size = 12
x_img = torch.randn(batch_size, 3, 256, 256)
x_ts = torch.randn(batch_size, 4)

print(f"Image shape: {x_img.shape}")
print(f"Time series shape: {x_ts.shape}")

with torch.no_grad():  # forward pass only; no autograd graph needed
    output = model(x_img, x_ts)
print(f"out shape: {output.shape}")
assert output.shape == (batch_size, 1), f"unexpected output shape {tuple(output.shape)}"



Image shape: torch.Size([12, 3, 256, 256])
Time series shape: torch.Size([12, 4])
out shape: torch.Size([12, 1])


In [None]:
## Build the dataset: monthly returns of agricultural futures
from torch.utils.data import Dataset, DataLoader

# Corn / Soybean / Feeder cattle / Live cattle
# NOTE(review): ZM=F is, to my knowledge, CBOT soybean *meal* (soybeans are ZS=F) — confirm intent.
Tickers = ["ZC=F", "ZM=F", "GF=F", "LE=F"]

data_price = []
for ticker in Tickers:
    # Only a start date is given, so the span presumably runs to "today" (yfinance
    # default end); the original "1 year of data" depends on when this is run.
    data = yf.download(ticker, start="2025-02-01", progress=False, auto_adjust=False, interval="1mo")
    if data.empty:
        raise ValueError(f"No price data downloaded for {ticker}")
    data_price.append(data["Adj Close"])

# Align all tickers on one date index BEFORE computing returns and dropping NaNs:
# dropping per ticker and then outer-joining can leave NaN holes wherever the
# tickers disagree on which months exist.
prices = pd.concat(data_price, axis=1)
prices.columns = ["Corn", "Soy", "Cow_food", "Cows"]
df = prices.pct_change(fill_method=None).dropna()

display(df.head())  # explicit display: this is not the last expression of the cell

def Get_image_data(df, width=256, height=256, delay=2.0,
                   layer="MODIS_Terra_CorrectedReflectance_TrueColor",
                   bbox="21,-128,49,-59"):
    """Download one NASA GIBS WMS snapshot per date in ``df.index``.

    Each index date is formatted as YYYY-MM-DD and sent as the TIME of a WMS
    1.3.0 GetMap request to the GIBS EPSG:4326 endpoint. A response that
    starts with the PNG signature is decoded with ``mpimg.imread``. Failures
    (HTTP errors, non-PNG bodies, exceptions) are printed and skipped, so the
    result can have fewer entries than ``df`` has rows.

    Args:
        df: DataFrame with a date-like index (must support ``.strftime``).
        width, height: requested image size in pixels. The default 256 x 256
            matches the input size Torch_CNN's head is built for by default.
        delay: seconds to sleep between consecutive requests (rate limiting).
        layer: GIBS layer name (default: MODIS Terra true-colour reflectance).
        bbox: WMS BBOX string. NOTE(review): with WMS 1.3.0 + EPSG:4326 this is
            presumably "minLat,minLon,maxLat,maxLon"; the default looks like the
            contiguous US — confirm.

    Returns:
        dict mapping "YYYY-MM-DD" strings to the arrays returned by
        ``mpimg.imread`` (presumably float (height, width, channels) in [0, 1]
        for PNG — confirm).
    """
    api_endpoint = "https://gibs.earthdata.nasa.gov/wms/epsg4326/best/wms.cgi"
    dates = df.index.strftime('%Y-%m-%d').tolist()  # ISO dates for the TIME parameter

    base_params = {
        "SERVICE": "WMS",
        "REQUEST": "GetMap",
        "VERSION": "1.3.0",
        "LAYERS": layer,
        "CRS": "EPSG:4326",
        "BBOX": bbox,
        "WIDTH": width,    # lowered resolution keeps downloads small
        "HEIGHT": height,
        "FORMAT": "image/png",
        "TRANSPARENT": "FALSE",
    }

    images_dict = {}
    for i, date in enumerate(dates):
        if i > 0:
            time.sleep(delay)  # be gentle with NASA's API between requests
        print(date)
        try:
            # Build a fresh params dict per request instead of mutating a shared one.
            response = requests.get(api_endpoint, params={**base_params, "TIME": date}, timeout=30)

            if response.status_code == 200:
                # GIBS may answer 200 with an XML/text error body, so check the PNG signature.
                if response.content.startswith(b'\x89PNG'):
                    images_dict[date] = mpimg.imread(io.BytesIO(response.content))
                else:
                    print(f"Response not PNG for {date}")
                    print(response.text[:200])
            else:
                print(f"Error {response.status_code} for {date}")

        except Exception as e:  # best-effort: report and continue with the next date
            print(f"Erreur pour {date}: {str(e)}")
    return images_dict

# Fetch one satellite image per date in df.index (sequential network requests with a pause between each).
images_dict = Get_image_data(df)